mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add 'fast' global pool option, remove redundant SEModule from tresnet, normal one is now 'fast'
This commit is contained in:
parent
90a01f47d1
commit
80c9d9cc72
@ -49,6 +49,15 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
|
||||
return x
|
||||
|
||||
|
||||
class FastAdaptiveAvgPool2d(nn.Module):
|
||||
def __init__(self, flatten=False):
|
||||
super(FastAdaptiveAvgPool2d, self).__init__()
|
||||
self.flatten = flatten
|
||||
|
||||
def forward(self, x):
|
||||
return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True)
|
||||
|
||||
|
||||
class AdaptiveAvgMaxPool2d(nn.Module):
|
||||
def __init__(self, output_size=1):
|
||||
super(AdaptiveAvgMaxPool2d, self).__init__()
|
||||
@ -70,12 +79,16 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
|
||||
class SelectAdaptivePool2d(nn.Module):
|
||||
"""Selectable global pooling layer with dynamic input kernel size
|
||||
"""
|
||||
def __init__(self, output_size=1, pool_type='avg', flatten=False):
|
||||
def __init__(self, output_size=1, pool_type='fast', flatten=False):
|
||||
super(SelectAdaptivePool2d, self).__init__()
|
||||
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
|
||||
self.flatten = flatten
|
||||
if pool_type == '':
|
||||
self.pool = nn.Identity() # pass through
|
||||
elif pool_type == 'fast':
|
||||
assert output_size == 1
|
||||
self.pool = FastAdaptiveAvgPool2d(self.flatten)
|
||||
self.flatten = False
|
||||
elif pool_type == 'avg':
|
||||
self.pool = nn.AdaptiveAvgPool2d(output_size)
|
||||
elif pool_type == 'avgmax':
|
||||
|
@ -14,7 +14,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead
|
||||
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
|
||||
from .registry import register_model
|
||||
|
||||
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
|
||||
@ -49,40 +49,6 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
class FastGlobalAvgPool2d(nn.Module):
|
||||
def __init__(self, flatten=False):
|
||||
super(FastGlobalAvgPool2d, self).__init__()
|
||||
self.flatten = flatten
|
||||
|
||||
def forward(self, x):
|
||||
if self.flatten:
|
||||
return x.mean((2, 3))
|
||||
else:
|
||||
return x.mean((2, 3), keepdim=True)
|
||||
|
||||
def feat_mult(self):
|
||||
return 1
|
||||
|
||||
|
||||
class FastSEModule(nn.Module):
|
||||
|
||||
def __init__(self, channels, reduction_channels, inplace=True):
|
||||
super(FastSEModule, self).__init__()
|
||||
self.avg_pool = FastGlobalAvgPool2d()
|
||||
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
|
||||
self.relu = nn.ReLU(inplace=inplace)
|
||||
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
|
||||
self.activation = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x_se = self.avg_pool(x)
|
||||
x_se2 = self.fc1(x_se)
|
||||
x_se2 = self.relu(x_se2)
|
||||
x_se = self.fc2(x_se2)
|
||||
x_se = self.activation(x_se)
|
||||
return x * x_se
|
||||
|
||||
|
||||
def IABN2Float(module: nn.Module) -> nn.Module:
|
||||
"""If `module` is IABN don't use half precision."""
|
||||
if isinstance(module, InplaceAbn):
|
||||
@ -119,8 +85,8 @@ class BasicBlock(nn.Module):
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
reduce_layer_planes = max(planes * self.expansion // 4, 64)
|
||||
self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None
|
||||
reduction_chs = max(planes * self.expansion // 4, 64)
|
||||
self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
|
||||
|
||||
def forward(self, x):
|
||||
if self.downsample is not None:
|
||||
@ -159,8 +125,8 @@ class Bottleneck(nn.Module):
|
||||
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
|
||||
aa_layer(channels=planes, filt_size=3, stride=2))
|
||||
|
||||
reduce_layer_planes = max(planes * self.expansion // 8, 64)
|
||||
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
|
||||
reduction_chs = max(planes * self.expansion // 8, 64)
|
||||
self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
|
||||
|
||||
self.conv3 = conv2d_iabn(
|
||||
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
|
||||
@ -189,7 +155,7 @@ class Bottleneck(nn.Module):
|
||||
|
||||
class TResNet(nn.Module):
|
||||
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
|
||||
global_pool='avg', drop_rate=0.):
|
||||
global_pool='fast', drop_rate=0.):
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
super(TResNet, self).__init__()
|
||||
@ -272,7 +238,7 @@ class TResNet(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
def reset_classifier(self, num_classes, global_pool='fast'):
|
||||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user