2021-04-12 09:38:02 -07:00
|
|
|
from .bottleneck_attn import BottleneckAttn
|
|
|
|
from .halo_attn import HaloAttn
|
2021-05-14 16:48:58 -07:00
|
|
|
from .involution import Involution
|
2021-04-12 09:38:02 -07:00
|
|
|
from .lambda_layer import LambdaLayer
|
2021-04-29 21:08:37 -07:00
|
|
|
from .swin_attn import WindowAttention
|
2021-04-12 09:38:02 -07:00
|
|
|
|
|
|
|
|
|
|
|
def get_self_attn(attn_type):
    """Return the self-attention layer class registered under ``attn_type``.

    Args:
        attn_type: one of ``'bottleneck'``, ``'halo'``, ``'lambda'``,
            ``'swin'`` or ``'involution'``.

    Returns:
        The layer class (not an instance): ``BottleneckAttn``, ``HaloAttn``,
        ``LambdaLayer``, ``WindowAttention`` or ``Involution`` respectively.

    Raises:
        AssertionError: if ``attn_type`` is not one of the names above.
    """
    # Each class name is only looked up on its own branch, so an unknown
    # type never touches any of the imported layer modules.
    if attn_type == 'bottleneck':
        return BottleneckAttn
    elif attn_type == 'halo':
        return HaloAttn
    elif attn_type == 'lambda':
        return LambdaLayer
    elif attn_type == 'swin':
        return WindowAttention
    elif attn_type == 'involution':
        return Involution
    # An explicit raise (instead of `assert False`) is not stripped under
    # `python -O`, so an unknown type can never silently yield None.
    # AssertionError is kept so existing callers that catch it still work.
    raise AssertionError(f"Unknown attn type ({attn_type})")
|
2021-04-12 09:38:02 -07:00
|
|
|
|
|
|
|
|
|
|
|
def create_self_attn(attn_type, dim, stride=1, **kwargs):
    """Build a self-attention layer instance by type name.

    Looks up the layer class with ``get_self_attn(attn_type)`` and constructs
    it with ``dim`` as the first positional argument, ``stride`` passed by
    keyword, and every extra keyword argument forwarded unchanged.

    Args:
        attn_type: registered attention type name (see ``get_self_attn``).
        dim: first positional constructor argument for the layer.
        stride: forwarded as the ``stride`` keyword; defaults to 1.
        **kwargs: additional constructor keyword arguments.

    Returns:
        The newly constructed layer instance.
    """
    return get_self_attn(attn_type)(dim, stride=stride, **kwargs)
|