184 lines
6.4 KiB
Python
184 lines
6.4 KiB
Python
""" PyTorch selectable adaptive pooling

Adaptive pooling with the ability to select the type of pooling from:
    * 'avg' - Average pooling
    * 'max' - Max pooling
    * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
    * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim

Both a functional and an nn.Module version of the pooling are provided. The module
version (SelectAdaptivePool2d) additionally accepts 'fast'-prefixed variants (e.g.
'fastavg', 'fastcatavgmax'), which reduce with mean/amax over the spatial dims and
require output_size == 1.

Hacked together by / Copyright 2020 Ross Wightman
"""
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .format import get_spatial_dim, get_channel_dim
|
|
|
|
_int_tuple_2_t = Union[int, Tuple[int, int]]
|
|
|
|
|
|
def adaptive_pool_feat_mult(pool_type='avg'):
    """Return the channel multiplier produced by a given pool type.

    Pool type names ending in 'catavgmax' (e.g. 'catavgmax', 'fastcatavgmax')
    concatenate the average- and max-pooled results, so they double the feature
    count. Every other pool type leaves it unchanged.
    """
    doubles_features = pool_type.endswith('catavgmax')
    return 2 if doubles_features else 1
|
|
|
|
|
|
def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
    """Blend adaptive average pooling and adaptive max pooling with equal weight.

    Args:
        x: Input tensor, passed unchanged to ``F.adaptive_avg_pool2d`` and
            ``F.adaptive_max_pool2d`` (presumably (N, C, H, W); confirm at call sites).
        output_size: Target spatial output size used by both pooling ops.

    Returns:
        ``0.5 * (avg_pooled + max_pooled)``, with the shape of either pooled tensor.
    """
    pooled_sum = F.adaptive_avg_pool2d(x, output_size) + F.adaptive_max_pool2d(x, output_size)
    return 0.5 * pooled_sum
|
|
|
|
|
|
def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
    """Concatenate adaptive average pooling and adaptive max pooling along dim 1.

    The average-pooled tensor comes first, followed by the max-pooled one, so the
    size of dim 1 doubles relative to either pooled tensor on its own.

    Args:
        x: Input tensor handed to both adaptive pooling ops.
        output_size: Target spatial output size used by both pooling ops.
    """
    pooled = [
        F.adaptive_avg_pool2d(x, output_size),
        F.adaptive_max_pool2d(x, output_size),
    ]
    return torch.cat(pooled, dim=1)
|
|
|
|
|
|
def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
    """Selectable global pooling function with dynamic input kernel size.

    Args:
        x: Input tensor handed to the selected adaptive pooling op.
        pool_type: Exactly one of 'avg', 'avgmax', 'catavgmax' or 'max'
            (case-sensitive; no 'fast' variants here).
        output_size: Target spatial output size.

    Returns:
        The pooled tensor. An unrecognised ``pool_type`` trips an assertion.
    """
    if pool_type == 'avg':
        return F.adaptive_avg_pool2d(x, output_size)
    if pool_type == 'avgmax':
        return adaptive_avgmax_pool2d(x, output_size)
    if pool_type == 'catavgmax':
        return adaptive_catavgmax_pool2d(x, output_size)
    if pool_type == 'max':
        return F.adaptive_max_pool2d(x, output_size)
    assert False, 'Invalid pool type: %s' % pool_type
    # Only reached when assertions are stripped (python -O): input passes through unchanged.
    return x
|
|
|
|
|
|
class FastAdaptiveAvgPool(nn.Module):
    """Global average pooling implemented as a ``mean`` over the spatial dims.

    Args:
        flatten: If True the reduced dims are dropped from the output; otherwise
            they are kept with size 1.
        input_fmt: Input layout string (e.g. 'NCHW'); the dims reduced over are
            obtained from ``get_spatial_dim(input_fmt)``.
    """

    # `input_fmt` is a layout string, matching the sibling Fast* pools.
    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
        super().__init__()
        self.flatten = flatten
        self.dim = get_spatial_dim(input_fmt)

    def forward(self, x):
        return x.mean(self.dim, keepdim=not self.flatten)
|
|
|
|
|
|
class FastAdaptiveMaxPool(nn.Module):
    """Global max pooling implemented as an ``amax`` reduction over the spatial dims.

    Args:
        flatten: Drop the reduced dims from the output when True; keep them as
            size-1 dims otherwise.
        input_fmt: Input layout string used to look up the spatial dims via
            ``get_spatial_dim``.
    """

    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
        super().__init__()
        self.dim = get_spatial_dim(input_fmt)
        self.flatten = flatten

    def forward(self, x):
        keep_spatial = not self.flatten
        return torch.amax(x, dim=self.dim, keepdim=keep_spatial)
|
|
|
|
|
|
class FastAdaptiveAvgMaxPool(nn.Module):
    """Equal-weight blend of global mean and global max over the spatial dims.

    Args:
        flatten: Drop the reduced dims from the output when True; keep them as
            size-1 dims otherwise.
        input_fmt: Input layout string used to look up the spatial dims via
            ``get_spatial_dim``.
    """

    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
        super().__init__()
        self.dim = get_spatial_dim(input_fmt)
        self.flatten = flatten

    def forward(self, x):
        keep = not self.flatten
        # Each half is scaled before summing: 0.5 * mean + 0.5 * max.
        avg_part = x.mean(self.dim, keepdim=keep)
        max_part = x.amax(self.dim, keepdim=keep)
        return avg_part.mul(0.5) + max_part.mul(0.5)
|
|
|
|
|
|
class FastAdaptiveCatAvgMaxPool(nn.Module):
    """Concatenate global mean and global max over the spatial dims.

    The mean result comes first, then the max result, joined along ``dim_cat``.

    Args:
        flatten: Drop the reduced dims from the output when True; keep them as
            size-1 dims otherwise.
        input_fmt: Input layout string; used to look up the reduced dims
            (``get_spatial_dim``) and, when not flattening, the concat dim
            (``get_channel_dim``).
    """

    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
        super().__init__()
        self.flatten = flatten
        self.dim_reduce = get_spatial_dim(input_fmt)
        # With flatten the spatial dims are removed, leaving (presumably) an (N, C)
        # tensor, so concatenate on dim 1; otherwise use the layout's channel dim.
        self.dim_cat = 1 if flatten else get_channel_dim(input_fmt)

    def forward(self, x):
        keep = not self.flatten
        pooled = [
            x.mean(self.dim_reduce, keepdim=keep),
            x.amax(self.dim_reduce, keepdim=keep),
        ]
        return torch.cat(pooled, self.dim_cat)
|
|
|
|
|
|
class AdaptiveAvgMaxPool2d(nn.Module):
    """Module wrapper around the functional ``adaptive_avgmax_pool2d``.

    Args:
        output_size: Target spatial output size forwarded on every call.
    """

    def __init__(self, output_size: _int_tuple_2_t = 1):
        super().__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_avgmax_pool2d(x, output_size=self.output_size)
|
|
|
|
|
|
class AdaptiveCatAvgMaxPool2d(nn.Module):
    """Module wrapper around the functional ``adaptive_catavgmax_pool2d``.

    Args:
        output_size: Target spatial output size forwarded on every call.
    """

    def __init__(self, output_size: _int_tuple_2_t = 1):
        super().__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_catavgmax_pool2d(x, output_size=self.output_size)
|
|
|
|
|
|
class SelectAdaptivePool2d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    Args:
        output_size: Target spatial output size. 'fast'-prefixed pool types, and any
            ``input_fmt`` other than 'NCHW', only support ``output_size == 1``.
        pool_type: One of 'avg', 'max', 'avgmax', 'catavgmax', optionally prefixed
            with 'fast' (plain 'fast' selects fast average pooling). Matching is
            case-insensitive. An empty string or None disables pooling (identity).
        flatten: Flatten the pooled output from dim 1 onwards.
        input_fmt: Input layout, 'NCHW' or 'NHWC'.
    """

    def __init__(
            self,
            output_size: _int_tuple_2_t = 1,
            pool_type: Optional[str] = 'fast',
            flatten: bool = False,
            input_fmt: str = 'NCHW',
    ):
        super().__init__()
        assert input_fmt in ('NCHW', 'NHWC')
        # Convert None / other falsy values to '' for consistent TS typing, and
        # lower-case once so the stored name matches the pool actually built
        # (feat_mult() and is_identity() read self.pool_type).
        pool_type = (pool_type or '').lower()
        self.pool_type = pool_type
        if not pool_type:
            self.pool = nn.Identity()  # pass through
            self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        elif pool_type.startswith('fast') or input_fmt != 'NCHW':
            assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
            # Order matters: 'catavgmax' and 'avgmax' both end in 'max'.
            if pool_type.endswith('catavgmax'):
                self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
            elif pool_type.endswith('avgmax'):
                self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
            elif pool_type.endswith('max'):
                self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
            elif pool_type == 'fast' or pool_type.endswith('avg'):
                self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
            else:
                assert False, 'Invalid pool type: %s' % pool_type
            # The fast pools apply `flatten` themselves (keepdim=not flatten).
            self.flatten = nn.Identity()
        else:
            assert input_fmt == 'NCHW'
            if pool_type == 'avgmax':
                self.pool = AdaptiveAvgMaxPool2d(output_size)
            elif pool_type == 'catavgmax':
                self.pool = AdaptiveCatAvgMaxPool2d(output_size)
            elif pool_type == 'max':
                self.pool = nn.AdaptiveMaxPool2d(output_size)
            elif pool_type == 'avg':
                self.pool = nn.AdaptiveAvgPool2d(output_size)
            else:
                assert False, 'Invalid pool type: %s' % pool_type
            self.flatten = nn.Flatten(1) if flatten else nn.Identity()

    def is_identity(self):
        """Return True when no pooling is applied (empty pool type)."""
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def feat_mult(self):
        """Channel multiplier of the pooled output (2 for 'catavgmax' variants, else 1)."""
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return f'{self.__class__.__name__}(pool_type={self.pool_type}, flatten={self.flatten})'
|
|
|