mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
EcaModule(CamelCase)
CamelCased EcaModule. Renamed all instances of ecalayer to EcaModule. eca_module.py->EcaModule.py
This commit is contained in:
parent
d04ff95eda
commit
db91ba053b
@ -36,7 +36,7 @@ from torch import nn
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
class eca_layer(nn.Module):
|
class EcaModule(nn.Module):
|
||||||
"""Constructs a ECA module.
|
"""Constructs a ECA module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -44,7 +44,7 @@ class eca_layer(nn.Module):
|
|||||||
k_size: Adaptive selection of kernel size
|
k_size: Adaptive selection of kernel size
|
||||||
"""
|
"""
|
||||||
def __init__(self, channel, k_size=3):
|
def __init__(self, channel, k_size=3):
|
||||||
super(eca_layer, self).__init__()
|
super(EcaModule, self).__init__()
|
||||||
assert k_size % 2 == 1
|
assert k_size % 2 == 1
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
|
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
|
||||||
@ -79,7 +79,7 @@ class eca_layer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ceca_layer(nn.Module):
|
class CecaModule(nn.Module):
|
||||||
"""Constructs a circular ECA module.
|
"""Constructs a circular ECA module.
|
||||||
the primary difference is that the conv uses a circular padding rather than zero padding.
|
the primary difference is that the conv uses a circular padding rather than zero padding.
|
||||||
This is because unlike images, the channels themselves do not have inherent ordering nor
|
This is because unlike images, the channels themselves do not have inherent ordering nor
|
||||||
@ -94,7 +94,7 @@ class ceca_layer(nn.Module):
|
|||||||
k_size: Adaptive selection of kernel size
|
k_size: Adaptive selection of kernel size
|
||||||
"""
|
"""
|
||||||
def __init__(self, channel, k_size=3):
|
def __init__(self, channel, k_size=3):
|
||||||
super(ceca_layer, self).__init__()
|
super(CecaModule, self).__init__()
|
||||||
assert k_size % 2 == 1
|
assert k_size % 2 == 1
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
#pytorch circular padding mode is bugged as of pytorch 1.4
|
#pytorch circular padding mode is bugged as of pytorch 1.4
|
@ -14,7 +14,7 @@ import torch.nn.functional as F
|
|||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
from .helpers import load_pretrained
|
from .helpers import load_pretrained
|
||||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||||
from .eca_module import eca_layer
|
from .EcaModule import EcaModule
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ class BasicBlock(nn.Module):
|
|||||||
self.bn2 = norm_layer(outplanes)
|
self.bn2 = norm_layer(outplanes)
|
||||||
|
|
||||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
||||||
self.eca = eca_layer(outplanes) if use_eca else None
|
self.eca = EcaModule(outplanes) if use_eca else None
|
||||||
|
|
||||||
self.act2 = act_layer(inplace=True)
|
self.act2 = act_layer(inplace=True)
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
@ -212,7 +212,7 @@ class Bottleneck(nn.Module):
|
|||||||
self.bn3 = norm_layer(outplanes)
|
self.bn3 = norm_layer(outplanes)
|
||||||
|
|
||||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
||||||
self.eca = eca_layer(outplanes) if use_eca else None
|
self.eca = Eca_Module(outplanes) if use_eca else None
|
||||||
|
|
||||||
self.act3 = act_layer(inplace=True)
|
self.act3 = act_layer(inplace=True)
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
|
Loading…
x
Reference in New Issue
Block a user