Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Improve performance of HaloAttn, change default dim calc. Some cleanup / fixes for byoanet. Rename resnet26ts to tfs to distinguish (extra fc).
commit 8449ba210c (parent a8b65695f1)
@@ -52,13 +52,12 @@ model_cfgs = dict(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
             ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
             ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         fixed_input_size=True,
         self_attn_layer='bottleneck',
         self_attn_kwargs=dict()
@@ -66,14 +65,13 @@ model_cfgs = dict(
     botnet50ts=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
-            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
-            ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=1, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='',
-        num_features=0,
         fixed_input_size=True,
         act_layer='silu',
         self_attn_layer='bottleneck',
@@ -83,13 +81,12 @@ model_cfgs = dict(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
             ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
             ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         fixed_input_size=True,
         act_layer='silu',
         attn_layer='eca',
@@ -107,7 +104,7 @@ model_cfgs = dict(
         stem_chs=64,
         stem_type='7x7',
         stem_pool='maxpool',
-        num_features=0,
+
         self_attn_layer='halo',
         self_attn_kwargs=dict(block_size=8, halo_size=3),
     ),
@@ -115,59 +112,57 @@ model_cfgs = dict(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
             ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         self_attn_layer='halo',
-        self_attn_kwargs=dict(block_size=8, halo_size=2)  # intended for 256x256 res
+        self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
     ),
     halonet50ts=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
-            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
-            ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
+            interleave_blocks(
+                types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
+                self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
+            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         act_layer='silu',
         self_attn_layer='halo',
-        self_attn_kwargs=dict(block_size=8, halo_size=2)
+        self_attn_kwargs=dict(block_size=8, halo_size=3)
     ),
     eca_halonext26ts=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
             ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
             ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         act_layer='silu',
         attn_layer='eca',
         self_attn_layer='halo',
-        self_attn_kwargs=dict(block_size=8, halo_size=2)  # intended for 256x256 res
+        self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
     ),

     lambda_resnet26t=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
             ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         self_attn_layer='lambda',
         self_attn_kwargs=dict(r=9)
     ),
@@ -185,7 +180,7 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):

 @register_model
 def botnet26t_256(pretrained=False, **kwargs):
-    """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
+    """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final two stages.
     """
     kwargs.setdefault('img_size', 256)
     return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
@@ -193,7 +188,7 @@ def botnet26t_256(pretrained=False, **kwargs):

 @register_model
 def botnet50ts_256(pretrained=False, **kwargs):
-    """ Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage.
+    """ Bottleneck Transformer w/ ResNet50-T backbone, silu act. Bottleneck attn in final two stages.
     """
     kwargs.setdefault('img_size', 256)
     return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
@@ -201,7 +196,7 @@ def botnet50ts_256(pretrained=False, **kwargs):

 @register_model
 def eca_botnext26ts_256(pretrained=False, **kwargs):
-    """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
+    """ Bottleneck Transformer w/ ResNet26-T backbone, silu act, Bottleneck attn in final two stages.
     """
     kwargs.setdefault('img_size', 256)
     return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
@@ -210,35 +205,34 @@ def eca_botnext26ts_256(pretrained=False, **kwargs):
 @register_model
 def halonet_h1(pretrained=False, **kwargs):
     """ HaloNet-H1. Halo attention in all stages as per the paper.
-
-    This runs very slowly, param count lower than paper --> something is wrong.
+    NOTE: This runs very slowly!
     """
     return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)


 @register_model
 def halonet26t(pretrained=False, **kwargs):
-    """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
+    """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
     """
     return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)


 @register_model
 def halonet50ts(pretrained=False, **kwargs):
-    """ HaloNet w/ a ResNet50-t backbone, Hallo attention in final stage
+    """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
     """
     return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)


 @register_model
 def eca_halonext26ts(pretrained=False, **kwargs):
-    """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
+    """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
     """
     return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)


 @register_model
 def lambda_resnet26t(pretrained=False, **kwargs):
-    """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
+    """ Lambda-ResNet-26T. Lambda layers in last two stages.
     """
     return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
@@ -107,13 +107,13 @@ default_cfgs = {
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic',
         min_input_size=(3, 256, 256)),

-    'resnet26ts': _cfg(
+    'resnet26tfs': _cfg(
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
-    'gcresnet26ts': _cfg(
+    'gcresnet26tfs': _cfg(
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
-    'seresnet26ts': _cfg(
+    'seresnet26tfs': _cfg(
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
-    'eca_resnet26ts': _cfg(
+    'eca_resnet26tfs': _cfg(
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),

     'gcresnet50t': _cfg(
@@ -174,13 +174,13 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):


 def interleave_blocks(
-        types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
+        types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
 ) -> Tuple[ByoBlockCfg]:
     """ interleave 2 block types in stack
     """
     assert len(types) == 2
     if isinstance(every, int):
-        every = list(range(0 if first else every, d, every))
+        every = list(range(0 if first else every, d, every + 1))
     if not every:
         every = [d - 1]
     set(every)
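For reference, a minimal sketch (not repo code) of the block-type pattern the reworked helper produces: with `every` now defaulting to 1 and the range step changed to `every + 1`, the self-attn block type lands on every other index unless `first=True` shifts the pattern to start at index 0. This is what lets the byoanet configs above drop their explicit `every=1` arguments.

# Sketch of the interleave pattern under the new defaults (not repo code).
def interleaved_types(types, d, every=1, first=False):
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every + 1))
    if not every:
        every = [d - 1]
    every = set(every)
    return [types[1] if i in every else types[0] for i in range(d)]

# e.g. the updated configs call interleave_blocks(types=('bottle', 'self_attn'), d=6, ...)
print(interleaved_types(('bottle', 'self_attn'), d=6))
# ['bottle', 'self_attn', 'bottle', 'self_attn', 'bottle', 'self_attn']

So a d=2 stage becomes one bottleneck block followed by one self-attn block, matching the "final two stages" wording in the updated docstrings.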
@@ -300,21 +300,6 @@ model_cfgs = dict(
         block_kwargs=dict(extra_conv=True),
     ),

-    # WARN: experimental, may vanish/change
-    geresnet50t=ByoModelCfg(
-        blocks=(
-            ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25),
-            ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25),
-            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
-            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
-        ),
-        stem_chs=64,
-        stem_type='tiered',
-        stem_pool=None,
-        attn_layer='ge',
-        attn_kwargs=dict(extent=8, extra_params=True),
-    ),
-
     # A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
     # and a tiered stem w/ maxpool
     resnext26ts=ByoModelCfg(
@@ -327,7 +312,6 @@ model_cfgs = dict(
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         act_layer='silu',
     ),
     gcresnext26ts=ByoModelCfg(
@@ -340,7 +324,6 @@ model_cfgs = dict(
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         act_layer='silu',
         attn_layer='gca',
     ),
@@ -354,8 +337,7 @@ model_cfgs = dict(
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
-        act_layer='relu',
+        act_layer='silu',
         attn_layer='se',
     ),
     eca_resnext26ts=ByoModelCfg(
@@ -368,7 +350,6 @@ model_cfgs = dict(
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         act_layer='silu',
         attn_layer='eca',
     ),
@@ -382,7 +363,6 @@ model_cfgs = dict(
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
-        num_features=0,
         act_layer='silu',
         attn_layer='bat',
         attn_kwargs=dict(block_size=8)
@@ -390,7 +370,7 @@ model_cfgs = dict(

     # A series of ResNet-26 models w/ one of none, GC, SE, ECA attn, no groups, SiLU act, 1280 feat fc
     # and a tiered stem w/ no maxpool
-    resnet26ts=ByoModelCfg(
+    resnet26tfs=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
@@ -403,7 +383,7 @@ model_cfgs = dict(
         num_features=0,
         act_layer='silu',
     ),
-    gcresnet26ts=ByoModelCfg(
+    gcresnet26tfs=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
||||||
@ -417,7 +397,7 @@ model_cfgs = dict(
|
|||||||
act_layer='silu',
|
act_layer='silu',
|
||||||
attn_layer='gca',
|
attn_layer='gca',
|
||||||
),
|
),
|
||||||
seresnet26ts=ByoModelCfg(
|
seresnet26tfs=ByoModelCfg(
|
||||||
blocks=(
|
blocks=(
|
||||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||||
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
||||||
@ -431,7 +411,7 @@ model_cfgs = dict(
|
|||||||
act_layer='silu',
|
act_layer='silu',
|
||||||
attn_layer='se',
|
attn_layer='se',
|
||||||
),
|
),
|
||||||
eca_resnet26ts=ByoModelCfg(
|
eca_resnet26tfs=ByoModelCfg(
|
||||||
blocks=(
|
blocks=(
|
||||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||||
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
|
||||||
@@ -455,7 +435,7 @@ model_cfgs = dict(
         ),
         stem_chs=64,
         stem_type='tiered',
-        stem_pool=None,
+        stem_pool='',
         attn_layer='gca',
     ),

@@ -614,31 +594,31 @@ def bat_resnext26ts(pretrained=False, **kwargs):


 @register_model
-def resnet26ts(pretrained=False, **kwargs):
+def resnet26tfs(pretrained=False, **kwargs):
     """
     """
-    return _create_byobnet('resnet26ts', pretrained=pretrained, **kwargs)
+    return _create_byobnet('resnet26tfs', pretrained=pretrained, **kwargs)


 @register_model
-def gcresnet26ts(pretrained=False, **kwargs):
+def gcresnet26tfs(pretrained=False, **kwargs):
     """
     """
-    return _create_byobnet('gcresnet26ts', pretrained=pretrained, **kwargs)
+    return _create_byobnet('gcresnet26tfs', pretrained=pretrained, **kwargs)


 @register_model
-def seresnet26ts(pretrained=False, **kwargs):
+def seresnet26tfs(pretrained=False, **kwargs):
     """
     """
-    return _create_byobnet('seresnet26ts', pretrained=pretrained, **kwargs)
+    return _create_byobnet('seresnet26tfs', pretrained=pretrained, **kwargs)


 @register_model
-def eca_resnet26ts(pretrained=False, **kwargs):
+def eca_resnet26tfs(pretrained=False, **kwargs):
     """
     """
-    return _create_byobnet('eca_resnet26ts', pretrained=pretrained, **kwargs)
+    return _create_byobnet('eca_resnet26tfs', pretrained=pretrained, **kwargs)


 @register_model
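Usage note (an assumption, not part of the diff): once the renames are registered, the variants are looked up under their new *tfs names.

import timm

# Hypothetical usage with this revision of timm installed; pretrained weights for the
# renamed variants are assumed unavailable here, so pretrained=False.
model = timm.create_model('resnet26tfs', pretrained=False)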
@@ -1144,27 +1124,29 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
     layer_fns = block_kwargs['layers']

     # override attn layer / args with block local config
-    if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None:
+    attn_set = block_cfg.attn_layer is not None
+    if attn_set or block_cfg.attn_kwargs is not None:
         # override attn layer config
-        if not block_cfg.attn_layer:
+        if attn_set and not block_cfg.attn_layer:
             # empty string for attn_layer type will disable attn for this block
             attn_layer = None
         else:
             attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
             attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
-            attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None
+            attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
         layer_fns = replace(layer_fns, attn=attn_layer)

     # override self-attn layer / args with block local cfg
-    if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None:
+    self_attn_set = block_cfg.self_attn_layer is not None
+    if self_attn_set or block_cfg.self_attn_kwargs is not None:
         # override attn layer config
-        if not block_cfg.self_attn_layer:
+        if self_attn_set and not block_cfg.self_attn_layer:  # attn_layer == ''
             # empty string for self_attn_layer type will disable attn for this block
             self_attn_layer = None
         else:
             self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
             self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
-            self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \
+            self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
                 if self_attn_layer is not None else None
         layer_fns = replace(layer_fns, self_attn=self_attn_layer)

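The `partial` change above is a genuine bug fix: splatting a kwargs dict with `*` passes only the dict's keys as positional arguments, while `**` forwards them as keyword arguments. A tiny illustration, where `make_attn` is a hypothetical stand-in for whatever constructor `get_attn` returns:

from functools import partial

def make_attn(channels, rd_ratio=0.25):
    return ('attn', channels, rd_ratio)

attn_kwargs = dict(rd_ratio=0.5)

broken = partial(make_attn, *attn_kwargs)    # == partial(make_attn, 'rd_ratio'): the key leaks in positionally
fixed = partial(make_attn, **attn_kwargs)    # == partial(make_attn, rd_ratio=0.5)

print(fixed(64))    # ('attn', 64, 0.5)
print(broken(64))   # ('attn', 'rd_ratio', 64) -- silently wrong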
@@ -103,19 +103,21 @@ class HaloAttn(nn.Module):
         - https://arxiv.org/abs/2103.12731
     """
     def __init__(
-            self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False):
+            self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
         super().__init__()
         dim_out = dim_out or dim
         assert dim_out % num_heads == 0
         self.stride = stride
         self.num_heads = num_heads
-        self.dim_head = dim_head
-        self.dim_qk = num_heads * dim_head
+        self.dim_head = dim_head or dim // num_heads
+        self.dim_qk = num_heads * self.dim_head
         self.dim_v = dim_out
         self.block_size = block_size
         self.halo_size = halo_size
         self.win_size = block_size + halo_size * 2  # neighbourhood window size
         self.scale = self.dim_head ** -0.5
+        # stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA)
+        self.stride_tricks = True

         # FIXME not clear if this stride behaviour is what the paper intended
         # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
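A small sketch (not repo code) of what the new default head-dim calculation in `__init__` means in practice: with `dim_head=None`, each head now gets `dim // num_heads` channels, so the q/k width scales with the layer width instead of being fixed at 16; passing `dim_head=16` restores the old behaviour, as the 26t halo configs above now do explicitly.

# Sketch of the new default, mirroring the two changed lines in __init__ (not repo code).
def halo_head_dims(dim, num_heads=8, dim_head=None):
    dim_head = dim_head or dim // num_heads
    dim_qk = num_heads * dim_head
    return dim_head, dim_qk

print(halo_head_dims(512))               # (64, 512)  q/k width follows the input dim
print(halo_head_dims(512, dim_head=16))  # (16, 128)  previous fixed-width behaviour, now opt-in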
@@ -139,28 +141,43 @@
         num_h_blocks = H // self.block_size
         num_w_blocks = W // self.block_size
         num_blocks = num_h_blocks * num_w_blocks
+        bs_stride = self.block_size // self.stride
+
         q = self.q(x)
-        q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
+        # q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride)  # don't need to use unfold here since no overlap
+        q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
         # B, num_heads * dim_head * block_size ** 2, num_blocks
         q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
         # B * num_heads, num_blocks, block_size ** 2, dim_head

         kv = self.kv(x)
-        # FIXME I 'think' this unfold does what I want it to, but I should investigate
+        # generate overlapping windows using either stride tricks (as_strided) or unfold
+        if self.stride_tricks:
+            # this is much faster
+            kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
+            kv = kv.as_strided((
+                B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
+                stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
+        else:
             kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
+
         kv = kv.reshape(
             B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
         k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
+        # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads

         attn_logits = (q @ k.transpose(-1, -2)) * self.scale  # FIXME should usual attn scale be applied?
         attn_logits = attn_logits + self.pos_embed(q)  # B * num_heads, block_size ** 2, win_size ** 2

         attn_out = attn_logits.softmax(dim=-1)
         attn_out = (attn_out @ v).transpose(1, 3)  # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
-        attn_out = F.fold(
-            attn_out.reshape(B, -1, num_blocks),
-            (H // self.stride, W // self.stride),
-            kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
+        # F.fold can be replaced by reshape + permute, slightly faster
+        # attn_out = F.fold(
+        #     attn_out.reshape(B, -1, num_blocks),
+        #     (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride)
+        attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
+        attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
         # B, dim_out, H // stride, W // stride
         return attn_out
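To make the stride-tricks branch above concrete, here is a self-contained equivalence check (shapes chosen arbitrarily, not repo code): the `as_strided` view over the padded tensor yields the same overlapping `win_size` windows that `F.unfold` extracts, without materialising an intermediate copy.

import torch
import torch.nn.functional as F

B, C, H, W = 2, 8, 16, 16
block_size, halo_size = 8, 3
win_size = block_size + halo_size * 2
num_h_blocks, num_w_blocks = H // block_size, W // block_size

kv = torch.randn(B, C, H, W)

# unfold path: slide a win_size window with block_size stride over the padded input
unfolded = F.unfold(kv, kernel_size=win_size, stride=block_size, padding=halo_size)
unfolded = unfolded.reshape(B, C, win_size, win_size, num_h_blocks, num_w_blocks)

# stride-trick path: pad once, then view the same overlapping windows without copying
padded = F.pad(kv, [halo_size] * 4).contiguous()
strided = padded.as_strided(
    (B, C, win_size, win_size, num_h_blocks, num_w_blocks),
    stride=(padded.stride(0), padded.stride(1), padded.shape[-1], 1,
            block_size * padded.shape[-1], block_size))

print(torch.allclose(unfolded, strided))  # True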