Mirror of https://github.com/huggingface/pytorch-image-models.git
Default lambda r=7. Define '26t' stage 4/5 256x256 variants for all of bot/halo/lambda nets for experiment. Add resnet50t for exp. Fix a few comments.
parent d15ad3e919
commit e15c3886ba

timm/models/byoanet.py
@@ -45,15 +45,16 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     # GPU-Efficient (ResNet) weights
+    'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256)),
     'botnet50t_224': _cfg(url='', fixed_input_size=True),
     'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),

     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'halonet26t': _cfg(url=''),
+    'halonet26t': _cfg(url='', input_size=(3, 256, 256)),
     'halonet50t': _cfg(url=''),

-    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
+    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256)),
     'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
 }
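
These _cfg entries populate each model's default_cfg; input_size and fixed_input_size record the resolution the variant is built for. One way to inspect the resolved config after this change (assuming this branch is installed locally):

    import timm

    m = timm.create_model('halonet26t', pretrained=False)
    print(m.default_cfg['input_size'])  # (3, 256, 256) after this commit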

@@ -92,6 +93,21 @@ def interleave_attn(

 model_cfgs = dict(

+    botnet26t=ByoaCfg(
+        blocks=(
+            ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        num_features=0,
+        self_attn_layer='bottleneck',
+        self_attn_fixed_size=True,
+        self_attn_kwargs=dict()
+    ),
     botnet50t=ByoaCfg(
         blocks=(
             ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
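
For readers unfamiliar with the helper used above: interleave_attn expands into d single-block stage configs, mixing the two block types so the attention type lands on every every-th block. A rough standalone sketch of that expansion, using plain dicts in place of ByoaBlocksCfg (the real helper's exact index handling may differ):

    def interleave_attn_sketch(types, every, d, first=False, **kwargs):
        # indices receiving the attention type (types[1]); the rest get types[0]
        attn_idx = set(range(0 if first else every, d, every)) or {d - 1}
        return tuple(
            dict(type=types[1] if i in attn_idx else types[0], d=1, **kwargs)
            for i in range(d))

    # every=1, d=2 as used for botnet26t: one 'bottle' block, then one 'self_attn' block
    print(interleave_attn_sketch(('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25))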

@@ -161,7 +177,7 @@ model_cfgs = dict(
         blocks=(
             ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
-            ByoaBlocksCfg(type='bottle', d=2, c=1024, s=2, gs=0, br=0.25),
+            interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
             ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,

@@ -169,7 +185,7 @@ model_cfgs = dict(
         stem_pool='maxpool',
         num_features=0,
         self_attn_layer='halo',
-        self_attn_kwargs=dict(block_size=7, halo_size=2)
+        self_attn_kwargs=dict(block_size=8, halo_size=2)  # intended for 256x256 res
     ),
     halonet50t=ByoaCfg(
         blocks=(
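
The block_size 7 -> 8 change matters because halo attention tiles the feature map into non-overlapping block_size x block_size blocks, so block_size must divide the spatial dims at every attention stage. At the 256x256 resolution these variants now target, the stage 4/5 feature maps (strides 16 and 32) are 16x16 and 8x8, which 8 divides cleanly and 7 does not:

    for stride in (16, 32):                 # strides of the stages carrying halo attn
        feat = 256 // stride                # 16, 8 at 256x256 input
        for bs in (7, 8):
            print(f'stride {stride}: {feat}x{feat} feat map, block_size={bs}:',
                  'ok' if feat % bs == 0 else 'misaligned')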

@@ -370,6 +386,14 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
         **kwargs)


+@register_model
+def botnet26t_256(pretrained=False, **kwargs):
+    """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
+    """
+    kwargs.setdefault('img_size', 256)
+    return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def botnet50t_224(pretrained=False, **kwargs):
     """ Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage.
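
Once registered, the new variant is constructible by name; fixed_input_size=True reflects the fixed-size position embeddings of the bottleneck attention, so the input must actually be 256x256. A minimal smoke test (assumes this branch is installed; no pretrained weights exist yet):

    import torch
    import timm

    model = timm.create_model('botnet26t_256', pretrained=False).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 256, 256))  # must match the fixed 256x256 size
    print(logits.shape)  # torch.Size([1, 1000])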

timm/models/layers/halo_attn.py

@@ -115,7 +115,7 @@ class HaloAttn(nn.Module):
         self.win_size = block_size + halo_size * 2  # neighbourhood window size
         self.scale = self.dim_head ** -0.5

-        # FIXME not clear if this stride behaviour is what the paper intended, not really clear
+        # FIXME not clear if this stride behaviour is what the paper intended
         # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
         # data in unfolded block form. I haven't wrapped my head around how that'd look.
         self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)

@@ -139,10 +139,10 @@ class HaloAttn(nn.Module):

         kv = self.kv(x)
         # FIXME I 'think' this unfold does what I want it to, but I should investigate
-        k = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
-        k = k.reshape(
+        kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
+        kv = kv.reshape(
             B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
-        k, v = torch.split(k, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
+        k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)

         attn_logits = (q @ k.transpose(-1, -2)) * self.scale  # FIXME should usual attn scale be applied?
         attn_logits = attn_logits + self.pos_embed(q)  # B * num_heads, block_size ** 2, win_size ** 2
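
The rename from k to kv reflects what the tensor actually holds before the split: keys and values unfolded together. For intuition on the unfold the FIXME frets about, F.unfold with kernel_size=win_size, stride=block_size, padding=halo_size gathers exactly one haloed window per block; a standalone shape check with illustrative sizes (not taken from the source):

    import torch
    import torch.nn.functional as F

    B, C, H, W = 2, 64, 16, 16              # e.g. a stride-16 feature map at 256x256 input
    block_size, halo_size = 8, 2
    win_size = block_size + halo_size * 2   # 12, as in __init__ above

    x = torch.randn(B, C, H, W)
    patches = F.unfold(x, kernel_size=win_size, stride=block_size, padding=halo_size)
    # one column of C * win_size**2 values per block; (16/8) * (16/8) = 4 blocks
    print(patches.shape)                    # torch.Size([2, 9216, 4])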

timm/models/layers/lambda_layer.py

@@ -34,7 +34,7 @@ class LambdaLayer(nn.Module):
     """
     def __init__(
             self,
-            dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=5, qkv_bias=False):
+            dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
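
Here r is the size of the local neighbourhood over which the position lambdas are computed via a convolution; odd values with padding r // 2 keep the spatial size unchanged, which is presumably why the default moves 5 -> 7 rather than to an even size. A toy illustration of that padding relationship (a plain Conv2d stand-in, not the layer's actual lambda conv):

    import torch
    import torch.nn as nn

    r = 7  # new default neighbourhood size
    conv = nn.Conv2d(16, 16, kernel_size=r, padding=r // 2)  # 'same' output for odd r
    x = torch.randn(1, 16, 32, 32)
    print(conv(x).shape)  # torch.Size([1, 16, 32, 32]) -- spatial dims preserved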

timm/models/resnet.py

@@ -54,6 +54,9 @@ default_cfgs = {
     'resnet50d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
         interpolation='bicubic', first_conv='conv1.0'),
+    'resnet50t': _cfg(
+        url='',
+        interpolation='bicubic', first_conv='conv1.0'),
     'resnet101': _cfg(url='', interpolation='bicubic'),
     'resnet101d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',

@@ -706,6 +709,15 @@ def resnet50d(pretrained=False, **kwargs):
     return _create_resnet('resnet50d', pretrained, **model_args)


+@register_model
+def resnet50t(pretrained=False, **kwargs):
+    """Constructs a ResNet-50-T model.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
+    return _create_resnet('resnet50t', pretrained, **model_args)
+
+
 @register_model
 def resnet101(pretrained=False, **kwargs):
     """Constructs a ResNet-101 model.
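
For reference, stem_type='deep_tiered' replaces the single 7x7 stem conv with three stacked 3x3 convs of tiered widths, and avg_down=True uses avg-pool downsampling in the shortcuts; that combination is what the '-t' suffix denotes here (the first_conv='conv1.0' cfg entry above points at the first conv inside that stem). A quick look at the resulting model (assumes this branch installed; no weights yet since url=''):

    import timm

    m = timm.create_model('resnet50t', pretrained=False)
    print(m.conv1)                       # sequential tiered deep stem of 3x3 convs
    print(m.default_cfg['first_conv'])   # 'conv1.0'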