pytorch-image-models/timm/models/layers/create_attn.py

""" Select AttentionFactory Method
Hacked together by Ross Wightman
"""
import torch
from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule


def create_attn(attn_type, channels, **kwargs):
    module_cls = None
    if attn_type is not None:
        if isinstance(attn_type, str):
            # string names map to the corresponding attention module class
            attn_type = attn_type.lower()
            if attn_type == 'se':
                module_cls = SEModule
            elif attn_type == 'ese':
                module_cls = EffectiveSEModule
            elif attn_type == 'eca':
                module_cls = EcaModule
            elif attn_type == 'ceca':
                module_cls = CecaModule
            elif attn_type == 'cbam':
                module_cls = CbamModule
            elif attn_type == 'lcbam':
                module_cls = LightCbamModule
            else:
                assert False, "Invalid attn module (%s)" % attn_type
        elif isinstance(attn_type, bool):
            # a bare True selects the default SE attention; False disables it
            if attn_type:
                module_cls = SEModule
        else:
            # otherwise assume attn_type is already a module class / callable
            module_cls = attn_type
    if module_cls is not None:
        return module_cls(channels, **kwargs)
    return None
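

# --- Illustrative usage (a minimal sketch, not part of the original file) ---
# create_attn() is typically called from a model block to build an optional
# attention layer by name; the channel count (64) and input shape below are
# arbitrary example values. Run as a module so the relative imports resolve.
if __name__ == '__main__':
    se = create_attn('se', 64)       # Squeeze-and-Excitation over 64 channels
    eca = create_attn('eca', 64)     # parameter-light ECA variant
    off = create_attn(None, 64)      # disabled -> returns None
    x = torch.randn(2, 64, 8, 8)
    print(se(x).shape, eca(x).shape, off is None)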