mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Improve performance of HaloAttn, change default dim calc. Some cleanup / fixes for byoanet. Rename resnet26ts to tfs to distinguish (extra fc).
This commit is contained in:
parent
a8b65695f1
commit
8449ba210c
@ -52,13 +52,12 @@ model_cfgs = dict(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
self_attn_layer='bottleneck',
|
||||
self_attn_kwargs=dict()
|
||||
@ -66,14 +65,13 @@ model_cfgs = dict(
|
||||
botnet50ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=1, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
act_layer='silu',
|
||||
self_attn_layer='bottleneck',
|
||||
@ -83,13 +81,12 @@ model_cfgs = dict(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
act_layer='silu',
|
||||
attn_layer='eca',
|
||||
@ -107,7 +104,7 @@ model_cfgs = dict(
|
||||
stem_chs=64,
|
||||
stem_type='7x7',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=3),
|
||||
),
|
||||
@ -115,59 +112,57 @@ model_cfgs = dict(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
|
||||
),
|
||||
halonet50ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(
|
||||
types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
|
||||
self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2)
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=3)
|
||||
),
|
||||
eca_halonext26ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
attn_layer='eca',
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
|
||||
),
|
||||
|
||||
lambda_resnet26t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
self_attn_layer='lambda',
|
||||
self_attn_kwargs=dict(r=9)
|
||||
),
|
||||
@ -185,7 +180,7 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
|
||||
|
||||
@register_model
|
||||
def botnet26t_256(pretrained=False, **kwargs):
|
||||
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
|
||||
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final two stages.
|
||||
"""
|
||||
kwargs.setdefault('img_size', 256)
|
||||
return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
|
||||
@ -193,7 +188,7 @@ def botnet26t_256(pretrained=False, **kwargs):
|
||||
|
||||
@register_model
|
||||
def botnet50ts_256(pretrained=False, **kwargs):
|
||||
""" Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage.
|
||||
""" Bottleneck Transformer w/ ResNet50-T backbone, silu act. Bottleneck attn in final two stages.
|
||||
"""
|
||||
kwargs.setdefault('img_size', 256)
|
||||
return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
|
||||
@ -201,7 +196,7 @@ def botnet50ts_256(pretrained=False, **kwargs):
|
||||
|
||||
@register_model
|
||||
def eca_botnext26ts_256(pretrained=False, **kwargs):
|
||||
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
|
||||
""" Bottleneck Transformer w/ ResNet26-T backbone, silu act, Bottleneck attn in final two stages.
|
||||
"""
|
||||
kwargs.setdefault('img_size', 256)
|
||||
return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
|
||||
@ -210,35 +205,34 @@ def eca_botnext26ts_256(pretrained=False, **kwargs):
|
||||
@register_model
|
||||
def halonet_h1(pretrained=False, **kwargs):
|
||||
""" HaloNet-H1. Halo attention in all stages as per the paper.
|
||||
|
||||
This runs very slowly, param count lower than paper --> something is wrong.
|
||||
NOTE: This runs very slowly!
|
||||
"""
|
||||
return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def halonet26t(pretrained=False, **kwargs):
|
||||
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
|
||||
""" HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
|
||||
"""
|
||||
return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def halonet50ts(pretrained=False, **kwargs):
|
||||
""" HaloNet w/ a ResNet50-t backbone, Hallo attention in final stage
|
||||
""" HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
|
||||
"""
|
||||
return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def eca_halonext26ts(pretrained=False, **kwargs):
|
||||
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
|
||||
""" HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
|
||||
"""
|
||||
return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def lambda_resnet26t(pretrained=False, **kwargs):
|
||||
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
|
||||
""" Lambda-ResNet-26T. Lambda layers in last two stages.
|
||||
"""
|
||||
return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
|
||||
|
@ -107,13 +107,13 @@ default_cfgs = {
|
||||
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic',
|
||||
min_input_size=(3, 256, 256)),
|
||||
|
||||
'resnet26ts': _cfg(
|
||||
'resnet26tfs': _cfg(
|
||||
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
||||
'gcresnet26ts': _cfg(
|
||||
'gcresnet26tfs': _cfg(
|
||||
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
||||
'seresnet26ts': _cfg(
|
||||
'seresnet26tfs': _cfg(
|
||||
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
||||
'eca_resnet26ts': _cfg(
|
||||
'eca_resnet26tfs': _cfg(
|
||||
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
|
||||
|
||||
'gcresnet50t': _cfg(
|
||||
@ -174,13 +174,13 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
|
||||
|
||||
|
||||
def interleave_blocks(
|
||||
types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
|
||||
types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
|
||||
) -> Tuple[ByoBlockCfg]:
|
||||
""" interleave 2 block types in stack
|
||||
"""
|
||||
assert len(types) == 2
|
||||
if isinstance(every, int):
|
||||
every = list(range(0 if first else every, d, every))
|
||||
every = list(range(0 if first else every, d, every + 1))
|
||||
if not every:
|
||||
every = [d - 1]
|
||||
set(every)
|
||||
@ -300,21 +300,6 @@ model_cfgs = dict(
|
||||
block_kwargs=dict(extra_conv=True),
|
||||
),
|
||||
|
||||
# WARN: experimental, may vanish/change
|
||||
geresnet50t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25),
|
||||
ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool=None,
|
||||
attn_layer='ge',
|
||||
attn_kwargs=dict(extent=8, extra_params=True),
|
||||
),
|
||||
|
||||
# A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
|
||||
# and a tiered stem w/ maxpool
|
||||
resnext26ts=ByoModelCfg(
|
||||
@ -327,7 +312,6 @@ model_cfgs = dict(
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
),
|
||||
gcresnext26ts=ByoModelCfg(
|
||||
@ -340,7 +324,6 @@ model_cfgs = dict(
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
attn_layer='gca',
|
||||
),
|
||||
@ -354,8 +337,7 @@ model_cfgs = dict(
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
act_layer='relu',
|
||||
act_layer='silu',
|
||||
attn_layer='se',
|
||||
),
|
||||
eca_resnext26ts=ByoModelCfg(
|
||||
@ -368,7 +350,6 @@ model_cfgs = dict(
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
attn_layer='eca',
|
||||
),
|
||||
@ -382,7 +363,6 @@ model_cfgs = dict(
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
attn_layer='bat',
|
||||
attn_kwargs=dict(block_size=8)
|
||||
@ -390,7 +370,7 @@ model_cfgs = dict(
|
||||
|
||||
# A series of ResNet-26 models w/ one of none, GC, SE, ECA attn, no groups, SiLU act, 1280 feat fc
|
||||
# and a tiered stem w/ no maxpool
|
||||
resnet26ts=ByoModelCfg(
|
||||
resnet26tfs=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
||||
@ -403,7 +383,7 @@ model_cfgs = dict(
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
),
|
||||
gcresnet26ts=ByoModelCfg(
|
||||
gcresnet26tfs=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
||||
@ -417,7 +397,7 @@ model_cfgs = dict(
|
||||
act_layer='silu',
|
||||
attn_layer='gca',
|
||||
),
|
||||
seresnet26ts=ByoModelCfg(
|
||||
seresnet26tfs=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
||||
@ -431,7 +411,7 @@ model_cfgs = dict(
|
||||
act_layer='silu',
|
||||
attn_layer='se',
|
||||
),
|
||||
eca_resnet26ts=ByoModelCfg(
|
||||
eca_resnet26tfs=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
||||
@ -455,7 +435,7 @@ model_cfgs = dict(
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool=None,
|
||||
stem_pool='',
|
||||
attn_layer='gca',
|
||||
),
|
||||
|
||||
@ -614,31 +594,31 @@ def bat_resnext26ts(pretrained=False, **kwargs):
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet26ts(pretrained=False, **kwargs):
|
||||
def resnet26tfs(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
return _create_byobnet('resnet26ts', pretrained=pretrained, **kwargs)
|
||||
return _create_byobnet('resnet26tfs', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def gcresnet26ts(pretrained=False, **kwargs):
|
||||
def gcresnet26tfs(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
return _create_byobnet('gcresnet26ts', pretrained=pretrained, **kwargs)
|
||||
return _create_byobnet('gcresnet26tfs', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet26ts(pretrained=False, **kwargs):
|
||||
def seresnet26tfs(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
return _create_byobnet('seresnet26ts', pretrained=pretrained, **kwargs)
|
||||
return _create_byobnet('seresnet26tfs', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def eca_resnet26ts(pretrained=False, **kwargs):
|
||||
def eca_resnet26tfs(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
return _create_byobnet('eca_resnet26ts', pretrained=pretrained, **kwargs)
|
||||
return _create_byobnet('eca_resnet26tfs', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
@ -1144,27 +1124,29 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
|
||||
layer_fns = block_kwargs['layers']
|
||||
|
||||
# override attn layer / args with block local config
|
||||
if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None:
|
||||
attn_set = block_cfg.attn_layer is not None
|
||||
if attn_set or block_cfg.attn_kwargs is not None:
|
||||
# override attn layer config
|
||||
if not block_cfg.attn_layer:
|
||||
if attn_set and not block_cfg.attn_layer:
|
||||
# empty string for attn_layer type will disable attn for this block
|
||||
attn_layer = None
|
||||
else:
|
||||
attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
|
||||
attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
|
||||
attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None
|
||||
attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
|
||||
layer_fns = replace(layer_fns, attn=attn_layer)
|
||||
|
||||
# override self-attn layer / args with block local cfg
|
||||
if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None:
|
||||
self_attn_set = block_cfg.self_attn_layer is not None
|
||||
if self_attn_set or block_cfg.self_attn_kwargs is not None:
|
||||
# override attn layer config
|
||||
if not block_cfg.self_attn_layer:
|
||||
if self_attn_set and not block_cfg.self_attn_layer: # attn_layer == ''
|
||||
# empty string for self_attn_layer type will disable attn for this block
|
||||
self_attn_layer = None
|
||||
else:
|
||||
self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
|
||||
self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
|
||||
self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \
|
||||
self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
|
||||
if self_attn_layer is not None else None
|
||||
layer_fns = replace(layer_fns, self_attn=self_attn_layer)
|
||||
|
||||
|
@ -103,19 +103,21 @@ class HaloAttn(nn.Module):
|
||||
- https://arxiv.org/abs/2103.12731
|
||||
"""
|
||||
def __init__(
|
||||
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False):
|
||||
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
|
||||
super().__init__()
|
||||
dim_out = dim_out or dim
|
||||
assert dim_out % num_heads == 0
|
||||
self.stride = stride
|
||||
self.num_heads = num_heads
|
||||
self.dim_head = dim_head
|
||||
self.dim_qk = num_heads * dim_head
|
||||
self.dim_head = dim_head or dim // num_heads
|
||||
self.dim_qk = num_heads * self.dim_head
|
||||
self.dim_v = dim_out
|
||||
self.block_size = block_size
|
||||
self.halo_size = halo_size
|
||||
self.win_size = block_size + halo_size * 2 # neighbourhood window size
|
||||
self.scale = self.dim_head ** -0.5
|
||||
# stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA)
|
||||
self.stride_tricks = True
|
||||
|
||||
# FIXME not clear if this stride behaviour is what the paper intended
|
||||
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
|
||||
@ -139,28 +141,43 @@ class HaloAttn(nn.Module):
|
||||
num_h_blocks = H // self.block_size
|
||||
num_w_blocks = W // self.block_size
|
||||
num_blocks = num_h_blocks * num_w_blocks
|
||||
bs_stride = self.block_size // self.stride
|
||||
|
||||
q = self.q(x)
|
||||
q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
|
||||
# q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride) # don't need to use unfold here since no overlap
|
||||
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
|
||||
# B, num_heads * dim_head * block_size ** 2, num_blocks
|
||||
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
|
||||
# B * num_heads, num_blocks, block_size ** 2, dim_head
|
||||
|
||||
kv = self.kv(x)
|
||||
# FIXME I 'think' this unfold does what I want it to, but I should investigate
|
||||
|
||||
# generate overlapping windows using either stride tricks (as_strided) or unfold
|
||||
if self.stride_tricks:
|
||||
# this is much faster
|
||||
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
|
||||
kv = kv.as_strided((
|
||||
B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
|
||||
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
|
||||
else:
|
||||
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
|
||||
|
||||
kv = kv.reshape(
|
||||
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
|
||||
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
|
||||
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
|
||||
|
||||
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
|
||||
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
|
||||
|
||||
attn_out = attn_logits.softmax(dim=-1)
|
||||
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
|
||||
attn_out = F.fold(
|
||||
attn_out.reshape(B, -1, num_blocks),
|
||||
(H // self.stride, W // self.stride),
|
||||
kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
|
||||
|
||||
# F.fold can be replaced by reshape + permute, slightly faster
|
||||
# attn_out = F.fold(
|
||||
# attn_out.reshape(B, -1, num_blocks),
|
||||
# (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride)
|
||||
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
|
||||
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
|
||||
# B, dim_out, H // stride, W // stride
|
||||
return attn_out
|
||||
|
Loading…
x
Reference in New Issue
Block a user