Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 1a8f5900ab (parent e01ccb88ce)

Update EfficientNet feature extraction for EfficientDet. Add needed MaxPoolSame as well.
timm/models/efficientnet.py

@@ -326,7 +326,6 @@ class EfficientNet(nn.Module):
         # Stem
         if not fix_stem:
             stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        print(stem_size)
         self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
@@ -393,7 +392,7 @@ class EfficientNetFeatures(nn.Module):
     and object detection models.
     """
 
-    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
+    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
                  in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@@ -404,6 +403,7 @@ class EfficientNetFeatures(nn.Module):
         num_stages = max(out_indices) + 1
 
         self.out_indices = out_indices
+        self.feature_location = feature_location
         self.drop_rate = drop_rate
         self._in_chs = in_chans
 
@@ -420,18 +420,23 @@ class EfficientNetFeatures(nn.Module):
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
-        self.feature_info = builder.features  # builder provides info about feature channels for each block
+        self._feature_info = builder.features  # builder provides info about feature channels for each block
+        self._stage_to_feature_idx = {
+            v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
         self._in_chs = builder.in_chs
 
         efficientnet_init_weights(self)
         if _DEBUG:
-            for k, v in self.feature_info.items():
+            for k, v in self._feature_info.items():
                 print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
 
         # Register feature extraction hooks with FeatureHooks helper
-        hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
-        hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
-        self.feature_hooks = FeatureHooks(hooks, self.named_modules())
+        self.feature_hooks = None
+        if feature_location != 'bottleneck':
+            hooks = [dict(
+                name=self._feature_info[idx]['module'],
+                type=self._feature_info[idx]['hook_type']) for idx in out_indices]
+            self.feature_hooks = FeatureHooks(hooks, self.named_modules())
 
     def feature_channels(self, idx=None):
         """ Feature Channel Shortcut
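
Not part of the commit: a minimal plain-PyTorch sketch of what the two hook_type values above map to. A 'forward' hook captures a module's output (e.g. act1), while a 'forward_pre' hook captures its input (the tensor just before the pointwise-linear conv). Module names and sizes below are made up.

import torch
import torch.nn as nn

captured = {}
# stand-in block: conv -> act -> pointwise conv
block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 1))

# 'forward' hook: fires after the module runs, sees its output
block[1].register_forward_hook(lambda m, inp, out: captured.update(act_out=out))
# 'forward_pre' hook: fires before the module runs, sees its input tuple
block[2].register_forward_pre_hook(lambda m, inp: captured.update(pwl_in=inp[0]))

block(torch.randn(1, 3, 16, 16))
print(captured['act_out'].shape, captured['pwl_in'].shape)  # both torch.Size([1, 8, 16, 16])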
@@ -439,15 +444,32 @@ class EfficientNetFeatures(nn.Module):
         return feature channel count for that feature block index (independent of out_indices setting).
         """
         if isinstance(idx, int):
-            return self.feature_info[idx]['num_chs']
-        return [self.feature_info[i]['num_chs'] for i in self.out_indices]
+            return self._feature_info[idx]['num_chs']
+        return [self._feature_info[i]['num_chs'] for i in self.out_indices]
+
+    def feature_info(self, idx=None):
+        """ Feature Channel Shortcut
+        Returns feature channel count for each output index if idx == None. If idx is an integer, will
+        return feature channel count for that feature block index (independent of out_indices setting).
+        """
+        if isinstance(idx, int):
+            return self._feature_info[idx]
+        return [self._feature_info[i] for i in self.out_indices]
 
     def forward(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
         x = self.act1(x)
-        self.blocks(x)
-        return self.feature_hooks.get_output(x.device)
+        if self.feature_hooks is None:
+            features = []
+            for i, b in enumerate(self.blocks):
+                x = b(x)
+                if i in self._stage_to_feature_idx:
+                    features.append(x)
+            return features
+        else:
+            self.blocks(x)
+            return self.feature_hooks.get_output(x.device)
 
 
 def _create_model(model_kwargs, default_cfg, pretrained=False):
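
As a usage-level illustration (not from the repo), the 'bottleneck' path added to forward() above simply walks the stages in order and keeps the outputs of the selected ones. A self-contained toy version of that collection pattern, with invented stage modules and channel counts:

import torch
import torch.nn as nn

class ToyFeatures(nn.Module):
    def __init__(self):
        super().__init__()
        # three "stages" standing in for self.blocks
        self.blocks = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
        )
        # stage index -> feature index, analogous to self._stage_to_feature_idx
        self._stage_to_feature_idx = {0: 0, 1: 1, 2: 2}

    def forward(self, x):
        features = []
        for i, b in enumerate(self.blocks):
            x = b(x)
            if i in self._stage_to_feature_idx:
                features.append(x)
        return features

fmaps = ToyFeatures()(torch.randn(1, 3, 64, 64))
print([f.shape for f in fmaps])  # feature maps at strides 2, 4, 8 relative to the input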
timm/models/efficientnet_blocks.py

@@ -120,11 +120,13 @@ class ConvBnAct(nn.Module):
         self.bn1 = norm_layer(out_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
 
-    def feature_module(self, location):
-        return 'act1'
-
-    def feature_channels(self, location):
-        return self.conv.out_channels
+    def feature_info(self, location):
+        if location == 'expansion' or location == 'depthwise':
+            # no expansion or depthwise this block, use act after conv
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
+        return info
 
     def forward(self, x):
         x = self.conv(x)
@@ -165,12 +167,15 @@ class DepthwiseSeparableConv(nn.Module):
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
 
-    def feature_module(self, location):
-        # no expansion in this block, pre pw only feature extraction point
-        return 'conv_pw'
-
-    def feature_channels(self, location):
-        return self.conv_pw.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            # no expansion in this block, use depthwise, before SE
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
+        elif location == 'depthwise':  # after SE
+            info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
+        return info
 
     def forward(self, x):
         residual = x
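
For reference (channel counts invented), the dicts these new feature_info(location) methods return all share one shape; an empty module/hook_type for 'bottleneck' means no hook is needed because the block's own output is used directly:

# e.g. for a DepthwiseSeparableConv with 32 mid channels and 16 output channels (made-up numbers)
info_expansion = dict(module='act1', hook_type='forward', num_chs=32)          # after depthwise act, before SE
info_depthwise = dict(module='conv_pw', hook_type='forward_pre', num_chs=32)   # input to the pointwise conv, after SE
info_bottleneck = dict(module='', hook_type='', num_chs=16)                    # block output, taken without a hook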
@@ -232,16 +237,14 @@ class InvertedResidual(nn.Module):
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
         self.bn3 = norm_layer(out_chs, **norm_kwargs)
 
-    def feature_module(self, location):
-        if location == 'post_exp':
-            return 'act1'
-        return 'conv_pwl'
-
-    def feature_channels(self, location):
-        if location == 'post_exp':
-            return self.conv_pw.out_channels
-        # location == 'pre_pw'
-        return self.conv_pwl.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
+        elif location == 'depthwise':  # after SE
+            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+        return info
 
     def forward(self, x):
         residual = x
@@ -359,16 +362,15 @@ class EdgeResidual(nn.Module):
             mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
 
-    def feature_module(self, location):
-        if location == 'post_exp':
-            return 'act1'
-        return 'conv_pwl'
-
-    def feature_channels(self, location):
-        if location == 'post_exp':
-            return self.conv_exp.out_channels
-        # location == 'pre_pw'
-        return self.conv_pwl.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels)
+        elif location == 'depthwise':
+            # there is no depthwise, take after SE, before PWL
+            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+        return info
 
     def forward(self, x):
         residual = x
timm/models/efficientnet_builder.py

@@ -218,7 +218,7 @@ class EfficientNetBuilder:
         self.norm_kwargs = norm_kwargs
         self.drop_path_rate = drop_path_rate
         self.feature_location = feature_location
-        assert feature_location in ('pre_pwl', 'post_exp', '')
+        assert feature_location in ('bottleneck', 'depthwise', 'expansion', '')
         self.verbose = verbose
 
         # state updated during build, consumed by model
@@ -313,20 +313,21 @@ class EfficientNetBuilder:
                     block_args['stride'] = 1
 
                 do_extract = False
-                if self.feature_location == 'pre_pwl':
+                if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
                     if last_block:
                         next_stage_idx = stage_idx + 1
                         if next_stage_idx >= len(model_block_args):
                             do_extract = True
                         else:
                             do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
-                elif self.feature_location == 'post_exp':
-                    if block_args['stride'] > 1 or (last_stack and last_block) :
+                elif self.feature_location == 'expansion':
+                    if block_args['stride'] > 1 or (last_stack and last_block):
                         do_extract = True
                 if do_extract:
                     extract_features = self.feature_location
 
                 next_dilation = current_dilation
+                next_output_stride = current_stride
                 if block_args['stride'] > 1:
                     next_output_stride = current_stride * block_args['stride']
                     if next_output_stride > self.output_stride:
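
A toy walk-through (stage data invented) of the 'bottleneck'/'depthwise' extraction rule above: a feature is tagged at the last block of a stage whose next stage downsamples, and at the final stage.

# made-up per-stage stride patterns standing in for decoded model_block_args
model_block_args = [
    [dict(stride=1)],
    [dict(stride=2), dict(stride=1)],
    [dict(stride=2), dict(stride=1)],
    [dict(stride=2)],
]
for stage_idx, stage in enumerate(model_block_args):
    for block_idx, block_args in enumerate(stage):
        last_block = block_idx + 1 == len(stage)
        do_extract = False
        if last_block:
            next_stage_idx = stage_idx + 1
            if next_stage_idx >= len(model_block_args):
                do_extract = True
            else:
                do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
        if do_extract:
            print('extract after stage {}, block {}'.format(stage_idx, block_idx))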
@@ -347,14 +348,13 @@ class EfficientNetBuilder:
 
                 # stash feature module name and channel info for model feature extraction
                 if extract_features:
-                    feature_module = block.feature_module(extract_features)
-                    if feature_module:
-                        feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module
-                    feature_channels = block.feature_channels(extract_features)
-                    self.features[feature_idx] = dict(
-                        name=feature_module,
-                        num_chs=feature_channels
-                    )
+                    feature_info = block.feature_info(extract_features)
+                    if feature_info['module']:
+                        feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module']
+                    feature_info['stage_idx'] = stage_idx
+                    feature_info['block_idx'] = block_idx
+                    feature_info['reduction'] = current_stride
+                    self.features[feature_idx] = feature_info
                     feature_idx += 1
 
                 total_block_idx += 1  # incr global block idx (across all stacks)
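
Illustrative only (numbers invented): with the change above, each entry the builder stashes in self.features carries the hook target plus its position and reduction factor, e.g. for a 'depthwise' extraction point at stage 2, block 1:

feature_entry = dict(
    module='blocks.2.1.conv_pwl',   # prefixed module name, empty string for 'bottleneck'
    hook_type='forward_pre',
    num_chs=40,
    stage_idx=2,
    block_idx=1,
    reduction=8,                    # current output stride at this point in the network
)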
timm/models/layers/__init__.py

@@ -1,9 +1,10 @@
 from .padding import get_padding
-from .avg_pool2d_same import AvgPool2dSame
+from .pool2d_same import AvgPool2dSame
 from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .mixed_conv2d import MixedConv2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
+from .pool2d_same import create_pool2d
 from .create_conv2d import create_conv2d
 from .create_attn import create_attn
 from .selective_kernel import SelectiveKernelConv
timm/models/layers/avg_pool2d_same.py (deleted)

@@ -1,31 +0,0 @@
-""" AvgPool2d w/ Same Padding
-
-Hacked together by Ross Wightman
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import List
-import math
-
-from .helpers import tup_pair
-from .padding import pad_same
-
-
-def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
-                    ceil_mode: bool = False, count_include_pad: bool = True):
-    x = pad_same(x, kernel_size, stride)
-    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
-
-
-class AvgPool2dSame(nn.AvgPool2d):
-    """ Tensorflow like 'SAME' wrapper for 2D average pooling
-    """
-    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
-        kernel_size = tup_pair(kernel_size)
-        stride = tup_pair(stride)
-        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
-
-    def forward(self, x):
-        return avg_pool2d_same(
-            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
timm/models/layers/cond_conv2d.py

@@ -14,7 +14,8 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 from .helpers import tup_pair
-from .conv2d_same import get_padding_value, conv2d_same
+from .conv2d_same import conv2d_same
+from timm.models.layers.padding import get_padding_value
 
 
 def get_condconv_initializer(initializer, num_experts, expert_shape):
timm/models/layers/conv2d_same.py

@@ -5,10 +5,10 @@ Hacked together by Ross Wightman
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Union, List, Tuple, Optional, Callable
-import math
+from typing import Tuple, Optional
 
-from .padding import get_padding, pad_same, is_static_pad
+from timm.models.layers.padding import get_padding_value
+from .padding import pad_same
 
 
 def conv2d_same(
@@ -31,29 +31,6 @@ class Conv2dSame(nn.Conv2d):
         return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
 
 
-def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
-    dynamic = False
-    if isinstance(padding, str):
-        # for any string padding, the padding will be calculated for you, one of three ways
-        padding = padding.lower()
-        if padding == 'same':
-            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
-            if is_static_pad(kernel_size, **kwargs):
-                # static case, no extra overhead
-                padding = get_padding(kernel_size, **kwargs)
-            else:
-                # dynamic 'SAME' padding, has runtime/GPU memory overhead
-                padding = 0
-                dynamic = True
-        elif padding == 'valid':
-            # 'VALID' padding, same as padding=0
-            padding = 0
-        else:
-            # Default to PyTorch style 'same'-ish symmetric padding
-            padding = get_padding(kernel_size, **kwargs)
-    return padding, dynamic
-
-
 def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
     padding = kwargs.pop('padding', '')
     kwargs.setdefault('bias', False)
timm/models/layers/padding.py

@@ -3,7 +3,7 @@
 Hacked together by Ross Wightman
 """
 import math
-from typing import List
+from typing import List, Tuple
 
 import torch.nn.functional as F
 
@@ -25,9 +25,32 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
 
 
 # Dynamically pad input x with 'SAME' padding for conv with specified args
-def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
     ih, iw = x.size()[-2:]
     pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
     if pad_h > 0 or pad_w > 0:
-        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
     return x
+
+
+def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+    dynamic = False
+    if isinstance(padding, str):
+        # for any string padding, the padding will be calculated for you, one of three ways
+        padding = padding.lower()
+        if padding == 'same':
+            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+            if is_static_pad(kernel_size, **kwargs):
+                # static case, no extra overhead
+                padding = get_padding(kernel_size, **kwargs)
+            else:
+                # dynamic 'SAME' padding, has runtime/GPU memory overhead
+                padding = 0
+                dynamic = True
+        elif padding == 'valid':
+            # 'VALID' padding, same as padding=0
+            padding = 0
+        else:
+            # Default to PyTorch style 'same'-ish symmetric padding
+            padding = get_padding(kernel_size, **kwargs)
+    return padding, dynamic
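
A standalone numeric check (re-derived here, not imported from the repo) of the 'SAME' padding arithmetic that pad_same() relies on, and of why the new value argument matters for max pooling: zero padding could win the max when real activations are negative, so -inf is used instead.

import math
import torch
import torch.nn.functional as F

def same_pad_amount(x, k, s, d=1):
    # total padding along one dim so that the output size equals ceil(x / s)
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

x = torch.randn(1, 1, 7, 7)
pad = same_pad_amount(7, k=3, s=2)                      # -> 2 for a 7x7 input
xp = F.pad(x, [pad // 2, pad - pad // 2] * 2, value=-float('inf'))
out = F.max_pool2d(xp, 3, 2)
print(pad, out.shape)                                   # 2 torch.Size([1, 1, 4, 4])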
timm/models/layers/pool2d_same.py (new file, 71 lines)

@@ -0,0 +1,71 @@
+""" AvgPool2d w/ Same Padding
+
+Hacked together by Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Union, List, Tuple, Optional
+import math
+
+from .helpers import tup_pair
+from .padding import pad_same, get_padding_value
+
+
+def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+                    ceil_mode: bool = False, count_include_pad: bool = True):
+    # FIXME how to deal with count_include_pad vs not for external padding?
+    x = pad_same(x, kernel_size, stride)
+    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+
+class AvgPool2dSame(nn.AvgPool2d):
+    """ Tensorflow like 'SAME' wrapper for 2D average pooling
+    """
+    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
+        kernel_size = tup_pair(kernel_size)
+        stride = tup_pair(stride)
+        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+    def forward(self, x):
+        return avg_pool2d_same(
+            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+
+def max_pool2d_same(
+        x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+        dilation: List[int] = (1, 1), ceil_mode: bool = False):
+    x = pad_same(x, kernel_size, stride, value=-float('inf'))
+    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
+
+
+class MaxPool2dSame(nn.MaxPool2d):
+    """ Tensorflow like 'SAME' wrapper for 2D max pooling
+    """
+    def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
+        kernel_size = tup_pair(kernel_size)
+        stride = tup_pair(stride)
+        super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
+
+    def forward(self, x):
+        return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
+
+
+def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
+    stride = stride or kernel_size
+    padding = kwargs.pop('padding', '')
+    padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
+    if is_dynamic:
+        if pool_type == 'avg':
+            return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
+        elif pool_type == 'max':
+            return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
+        else:
+            assert False, f'Unsupported pool type {pool_type}'
+    else:
+        if pool_type == 'avg':
+            return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+        elif pool_type == 'max':
+            return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+        else:
+            assert False, f'Unsupported pool type {pool_type}'
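
A possible usage sketch for the new helper (assuming the exports added to layers/__init__.py above are available): padding='same' with stride 2 cannot be expressed as static symmetric padding, so create_pool2d should return the dynamic MaxPool2dSame wrapper rather than a plain nn.MaxPool2d.

import torch
from timm.models.layers import create_pool2d  # export added in the __init__.py change above

pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
y = pool(torch.randn(1, 32, 7, 7))
print(type(pool).__name__, y.shape)   # MaxPool2dSame torch.Size([1, 32, 4, 4])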