mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

ByoaNet with bottleneck transformer, lambda resnet, and halo net experiments

This commit is contained in:
parent 21812d33aa
commit ce62f96d4d
timm/models/__init__.py

@@ -1,3 +1,4 @@
+from .byoanet import *
 from .byobnet import *
 from .cspnet import *
 from .densenet import *
@@ -39,5 +40,4 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
-from .registry import *
+from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules

timm/models/byoanet.py  (new file, 427 lines)

@@ -0,0 +1,427 @@
""" Bring-Your-Own-Attention Network

A flexible network w/ dataclass based config for stacking NN blocks including
self-attention (or similar) layers.

Currently used to implement experimental variants of:
  * Bottleneck Transformers
  * Lambda ResNets
  * HaloNets

Consider all of the models here WIP and likely to change.

Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable
from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample, \
    reduce_feat_size, register_block, num_groups, LayerFn, _init_weights
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn, \
    make_divisible, to_2tuple
from .registry import register_model

__all__ = ['ByoaNet']


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = {
    # GPU-Efficient (ResNet) weights
    'botnet50t_224': _cfg(url=''),
    'botnet50t_c4c5_224': _cfg(url=''),

    'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'halonet26t': _cfg(url=''),
    'halonet50t': _cfg(url=''),

    'lambda_resnet26t': _cfg(url=''),
    'lambda_resnet50t': _cfg(url=''),
}


@dataclass
class ByoaBlocksCfg(BlocksCfg):
    # FIXME allow overriding self_attn layer or args per block/stage
    pass


@dataclass
class ByoaCfg(ByobCfg):
    blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None
    self_attn_layer: Optional[str] = None
    self_attn_fixed_size: bool = False
    self_attn_kwargs: dict = field(default_factory=lambda: dict())


def interleave_attn(
        types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoaBlocksCfg]:
    """ interleave attn blocks
    """
    assert len(types) == 2
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every))
        if not every:
            every = [d - 1]
    set(every)
    blocks = []
    for i in range(d):
        block_type = types[1] if i in every else types[0]
        blocks += [ByoaBlocksCfg(type=block_type, d=1, **kwargs)]
    return tuple(blocks)


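# Illustration (not part of the configs below): with the lambda_resnet26t stage settings,
#   interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25)
# expands to
#   (ByoaBlocksCfg(type='bottle', d=1, c=1024, s=2, gs=0, br=0.25),
#    ByoaBlocksCfg(type='self_attn', d=1, c=1024, s=2, gs=0, br=0.25))
# i.e. every second block in that stage swaps its kxk conv for the self-attention layer.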
model_cfgs = dict(

    botnet50t=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=6, c=1024, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=0,
        self_attn_layer='bottleneck',
        self_attn_fixed_size=True,
        self_attn_kwargs=dict()
    ),
    botnet50t_c4c5=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
            (
                ByoaBlocksCfg(type='self_attn', d=1, c=1024, s=2, gs=0, br=0.25),
                ByoaBlocksCfg(type='bottle', d=5, c=1024, s=1, gs=0, br=0.25),
            ),
            (
                ByoaBlocksCfg(type='self_attn', d=1, c=2048, s=2, gs=0, br=0.25),
                ByoaBlocksCfg(type='bottle', d=2, c=2048, s=1, gs=0, br=0.25),
            )
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='bottleneck',
        self_attn_fixed_size=True,
        self_attn_kwargs=dict()
    ),

    halonet_h1=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
            ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
            ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
            ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
        ),
        stem_chs=64,
        stem_type='7x7',
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3),
    ),
    halonet_h1_c4c5=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
            ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
            ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
            ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3),
    ),
    halonet26t=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=2, c=1024, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=7, halo_size=2)
    ),
    halonet50t=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=6, c=1024, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=7, halo_size=2)
    ),

    lambda_resnet26t=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='lambda',
        self_attn_kwargs=dict()
    ),
    lambda_resnet50t=ByoaCfg(
        blocks=(
            ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
            interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        num_features=0,
        self_attn_layer='lambda',
        self_attn_kwargs=dict()
    ),
)
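# Key to the ByoaBlocksCfg fields used above (inherited from BlocksCfg in byobnet.py):
#   d = number of blocks in the stage, c = output channels, s = stride of the first
#   block in the stage, gs = group-size for the kxk convs, br = bottleneck ratio.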


@dataclass
class ByoaLayerFn(LayerFn):
    self_attn: Optional[Callable] = None


class SelfAttnBlock(nn.Module):
    """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
                 downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
                 layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.):
        super(SelfAttnBlock, self).__init__()
        assert layers is not None
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
                apply_act=False, layers=layers)
        else:
            self.shortcut = nn.Identity()

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        if extra_conv:
            self.conv2_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
                groups=groups, drop_block=drop_block)
            stride = 1  # striding done via conv if enabled
        else:
            self.conv2_kxk = nn.Identity()
        opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
        # FIXME need to dilate self attn to have dilated network support, moop moop
        self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
        self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        if zero_init_last_bn:
            nn.init.zeros_(self.conv3_1x1.bn.weight)

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = self.conv1_1x1(x)
        x = self.conv2_kxk(x)
        x = self.self_attn(x)
        x = self.post_attn(x)
        x = self.conv3_1x1(x)
        x = self.drop_path(x)

        x = self.act(x + shortcut)
        return x


register_block('self_attn', SelfAttnBlock)


def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None):
    if block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size:
        assert feat_size is not None
        block_kwargs['feat_size'] = feat_size
    return block_kwargs


def get_layer_fns(cfg: ByoaCfg):
    act = get_act_layer(cfg.act_layer)
    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
    self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
    layer_fn = ByoaLayerFn(
        conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
    return layer_fn


class ByoaNet(nn.Module):
    """ 'Bring-your-own-attention' Net

    A ResNet inspired backbone that supports interleaving traditional residual blocks with
    'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention
    or similar module.

    FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but
    torchscript limitations prevent sensible inheritance overrides.
    """
    def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
                 zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        layers = get_layer_fns(cfg)
        feat_size = to_2tuple(img_size) if img_size is not None else None

        self.feature_info = []
        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
        self.feature_info.extend(stem_feat[:-1])
        feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])

        self.stages, stage_feat = create_byob_stages(
            cfg, drop_path_rate, output_stride, stem_feat[-1],
            feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args)
        self.feature_info.extend(stage_feat[:-1])

        prev_chs = stage_feat[-1]['num_chs']
        if cfg.num_features:
            self.num_features = int(round(cfg.width_factor * cfg.num_features))
            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
        else:
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        self.feature_info += [
            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]

        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

        for n, m in self.named_modules():
            _init_weights(m, n)
        for m in self.modules():
            # call each block's weight init for block-specific overrides to init above
            if hasattr(m, 'init_weights'):
                m.init_weights(zero_init_last_bn=zero_init_last_bn)

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.final_conv(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ByoaNet, variant, pretrained,
        default_cfg=default_cfgs[variant],
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)


@register_model
def botnet50t_224(pretrained=False, **kwargs):
    """ Bottleneck Transformer w/ ResNet50-T backbone, self-attention in the final stage.
    """
    kwargs.setdefault('img_size', 224)
    return _create_byoanet('botnet50t_224', 'botnet50t', pretrained=pretrained, **kwargs)


@register_model
def botnet50t_c4c5_224(pretrained=False, **kwargs):
    """ Bottleneck Transformer w/ ResNet50-T backbone, self-attention in the first block of the last two stages.
    """
    kwargs.setdefault('img_size', 224)
    return _create_byoanet('botnet50t_c4c5_224', 'botnet50t_c4c5', pretrained=pretrained, **kwargs)


@register_model
def halonet_h1(pretrained=False, **kwargs):
    """ HaloNet-H1 config, halo attention in every stage.
    """
    return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)


@register_model
def halonet_h1_c4c5(pretrained=False, **kwargs):
    """ HaloNet-H1 variant w/ halo attention in the last two stages only.
    """
    return _create_byoanet('halonet_h1_c4c5', pretrained=pretrained, **kwargs)


@register_model
def halonet26t(pretrained=False, **kwargs):
    """ HaloNet w/ ResNet26-T backbone, halo attention in the final stage.
    """
    return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)


@register_model
def halonet50t(pretrained=False, **kwargs):
    """ HaloNet w/ ResNet50-T backbone, halo attention in the final stage.
    """
    return _create_byoanet('halonet50t', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet26t(pretrained=False, **kwargs):
    """ Lambda-ResNet 26-T, lambda layers interleaved into the last two stages.
    """
    return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet50t(pretrained=False, **kwargs):
    """ Lambda-ResNet 50-T, lambda layers interleaved into the last two stages.
    """
    return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs)
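With these entrypoints registered, the new experimental models are reachable through the usual timm factory functions. A minimal usage sketch (assuming an install that includes this commit; no pretrained weights are published yet, so pretrained stays False):

    import torch
    import timm

    print(timm.list_models('halonet*'))            # the halo variants registered above
    model = timm.create_model('halonet26t', pretrained=False).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))   # default_cfg input_size for halonet26t
    print(out.shape)                               # expected: torch.Size([1, 1000])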

timm/models/byobnet.py

@@ -25,9 +25,9 @@ above nets that include attention.
 Hacked together by / copyright Ross Wightman, 2021.
 """
 import math
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from collections import OrderedDict
-from typing import Tuple, Dict, Optional, Union, Any, Callable
+from typing import Tuple, List, Optional, Union, Any, Callable, Sequence
 from functools import partial

 import torch
@@ -35,11 +35,11 @@ import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, ConvBnAct, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, get_attn, convert_norm_act, make_divisible
+from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible
 from .registry import register_model

-__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg']
+__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block']


 def _cfg(url='', **kwargs):
@@ -98,20 +98,22 @@ class BlocksCfg:
     s: int = 2  # stride of stage (first block)
     gs: Optional[Union[int, Callable]] = None  # group-size of blocks in stage, conv is depthwise if gs == 1
     br: float = 1.  # bottleneck-ratio of blocks in stage
+    no_attn: bool = True  # disable channel attn (ie SE) when layer is set for model


 @dataclass
 class ByobCfg:
-    blocks: Tuple[BlocksCfg, ...]
+    blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...]
     downsample: str = 'conv1x1'
     stem_type: str = '3x3'
+    stem_pool: str = ''
     stem_chs: int = 32
     width_factor: float = 1.0
    num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
     zero_init_last_bn: bool = True

     act_layer: str = 'relu'
-    norm_layer: nn.Module = nn.BatchNorm2d
+    norm_layer: str = 'batchnorm'
     attn_layer: Optional[str] = None
     attn_kwargs: dict = field(default_factory=lambda: dict())

@@ -201,17 +203,29 @@ model_cfgs = dict(
         stem_type='rep',
         stem_chs=64,
     ),
+
+    resnet52q=ByobCfg(
+        blocks=(
+            BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
+            BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
+        ),
+        stem_chs=128,
+        stem_type='quad',
+        num_features=2048,
+        act_layer='silu',
+    ),
 )


-def _na_args(cfg: dict):
-    return dict(
-        norm_layer=cfg.get('norm_layer', nn.BatchNorm2d),
-        act_layer=cfg.get('act_layer', nn.ReLU))
-
-
-def _ex_tuple(cfg: dict, *names):
-    return tuple([cfg.get(n, None) for n in names])
+def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]:
+    if not isinstance(stage_blocks_cfg, Sequence):
+        stage_blocks_cfg = (stage_blocks_cfg,)
+    block_cfgs = []
+    for i, cfg in enumerate(stage_blocks_cfg):
+        block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)]
+    return block_cfgs


 def num_groups(group_size, channels):
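For reference, a small illustration of the new expand_blocks_cfg helper (made-up values, not part of the diff): a single per-stage BlocksCfg with d=3 expands into three per-block cfgs with d=1, which is what allows a stage to be specified either as one cfg or as an explicit tuple of block cfgs (as the ByoaNet configs do):

    expand_blocks_cfg(BlocksCfg(type='bottle', d=3, c=256, s=1))
    # -> [BlocksCfg(type='bottle', d=1, c=256, s=1, ...)] three single-block copies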
@@ -223,27 +237,36 @@ def num_groups(group_size, channels):
         return channels // group_size


+@dataclass
+class LayerFn:
+    conv_norm_act: Callable = ConvBnAct
+    norm_act: Callable = BatchNormAct2d
+    act: Callable = nn.ReLU
+    attn: Optional[Callable] = None
+
+
 class DownsampleAvg(nn.Module):
-    def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, norm_layer=None, act_layer=None):
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, layers: LayerFn = None):
         """ AvgPool Downsampling as in 'D' ResNet variants."""
         super(DownsampleAvg, self).__init__()
+        layers = layers or LayerFn()
         avg_stride = stride if dilation == 1 else 1
         if stride > 1 or dilation > 1:
             avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
             self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
         else:
             self.pool = nn.Identity()
-        self.conv = ConvBnAct(in_chs, out_chs, 1, apply_act=apply_act, norm_layer=norm_layer, act_layer=act_layer)
+        self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)

     def forward(self, x):
         return self.conv(self.pool(x))


-def create_downsample(type, **kwargs):
-    if type == 'avg':
+def create_downsample(downsample_type, layers: LayerFn, **kwargs):
+    if downsample_type == 'avg':
         return DownsampleAvg(**kwargs)
     else:
-        return ConvBnAct(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)
+        return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)


 class BasicBlock(nn.Module):
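The LayerFn dataclass introduced above replaces the old layer_cfg dict / _na_args / _ex_tuple plumbing: every block now receives one bundle of layer constructors (conv+norm+act, norm+act, act, optional channel attention). A rough standalone sketch of the pattern with the default callables and the BottleneckBlock defined below (illustrative only, not part of the diff):

    layers = LayerFn()                      # ConvBnAct / BatchNormAct2d / nn.ReLU, attn=None
    block = BottleneckBlock(64, 256, bottle_ratio=0.25, layers=layers)
    y = block(torch.randn(1, 64, 56, 56))   # should yield a (1, 256, 56, 56) tensor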
@@ -252,28 +275,25 @@ class BasicBlock(nn.Module):

     def __init__(
             self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
-            downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+            downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(BasicBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_kxk = ConvBnAct(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], **layer_args)
-        self.conv2_kxk = ConvBnAct(
-            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups,
-            drop_block=drop_block, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+        self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
+        self.conv2_kxk = layers.conv_norm_act(
+            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -297,29 +317,27 @@ class BottleneckBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(BottleneckBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args)
-        self.conv2_kxk = ConvBnAct(
+        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+        self.conv2_kxk = layers.conv_norm_act(
             mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(mid_chs)
-        self.conv3_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args)
+            groups=groups, drop_block=drop_block)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs)
+        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -350,28 +368,26 @@ class DarkBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(DarkBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_1x1 = ConvBnAct(in_chs, mid_chs, 1, **layer_args)
-        self.conv2_kxk = ConvBnAct(
+        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+        self.conv2_kxk = layers.conv_norm_act(
             mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+            groups=groups, drop_block=drop_block, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -399,28 +415,26 @@ class EdgeBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='avg', linear_out=False, layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(EdgeBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'attn_layer')
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
         groups = num_groups(group_size, mid_chs)

         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
             self.shortcut = create_downsample(
                 downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
-                apply_act=False, **layer_args)
+                apply_act=False, layers=layers)
         else:
             self.shortcut = nn.Identity()

-        self.conv1_kxk = ConvBnAct(
+        self.conv1_kxk = layers.conv_norm_act(
             in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
-        self.conv2_1x1 = ConvBnAct(mid_chs, out_chs, 1, apply_act=False, **layer_args)
+            groups=groups, drop_block=drop_block)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
+        self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.act = nn.Identity() if linear_out else act_layer(inplace=True)
+        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
@@ -446,23 +460,20 @@ class RepVggBlock(nn.Module):
     """

     def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-                 downsample='', layer_cfg=None, drop_block=None, drop_path_rate=0.):
+                 downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
         super(RepVggBlock, self).__init__()
-        layer_cfg = layer_cfg or {}
-        act_layer, norm_layer, attn_layer = _ex_tuple(layer_cfg, 'act_layer', 'norm_layer', 'attn_layer')
-        norm_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer)
-        layer_args = _na_args(layer_cfg)
+        layers = layers or LayerFn()
         groups = num_groups(group_size, in_chs)

         use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
-        self.identity = norm_layer(out_chs, apply_act=False) if use_ident else None
-        self.conv_kxk = ConvBnAct(
+        self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
+        self.conv_kxk = layers.conv_norm_act(
             in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
-        self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args)
-        self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
+            groups=groups, drop_block=drop_block, apply_act=False)
+        self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
+        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
-        self.act = act_layer(inplace=True)
+        self.act = layers.act(inplace=True)

     def init_weights(self, zero_init_last_bn=False):
         # NOTE this init overrides that base model init with specific changes for the block type
@@ -504,33 +515,200 @@ def create_block(block: Union[str, nn.Module], **kwargs):
     return _block_registry[block](**kwargs)


-def create_stem(in_chs, out_chs, stem_type='', layer_cfg=None):
-    layer_cfg = layer_cfg or {}
-    layer_args = _na_args(layer_cfg)
-    assert stem_type in ('', 'deep', 'deep_tiered', '3x3', '7x7', 'rep')
-    if 'deep' in stem_type:
-        # 3 deep 3x3 conv stack
-        stem = OrderedDict()
-        stem_chs = (out_chs // 2, out_chs // 2)
-        if 'tiered' in stem_type:
-            stem_chs = (3 * stem_chs[0] // 4, stem_chs[1])
-        norm_layer, act_layer = _ex_tuple(layer_args, 'norm_layer', 'act_layer')
-        stem['conv1'] = create_conv2d(in_chs, stem_chs[0], kernel_size=3, stride=2)
-        stem['conv2'] = create_conv2d(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
-        stem['conv3'] = create_conv2d(stem_chs[1], out_chs, kernel_size=3, stride=1)
-        norm_act_layer = convert_norm_act(norm_layer=norm_layer, act_layer=act_layer)
-        stem['na'] = norm_act_layer(out_chs)
-        stem = nn.Sequential(stem)
-    elif '7x7' in stem_type:
-        # 7x7 stem conv as in ResNet
-        stem = ConvBnAct(in_chs, out_chs, 7, stride=2, **layer_args)
-    elif 'rep' in stem_type:
-        stem = RepVggBlock(in_chs, out_chs, stride=2, layer_cfg=layer_cfg)
-    else:
-        # 3x3 stem conv as in RegNet
-        stem = ConvBnAct(in_chs, out_chs, 3, stride=2, **layer_args)
-
-    return stem
+# class Stem(nn.Module):
+#
+#     def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
+#                  num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
+#         super().__init__()
+#         assert stride in (2, 4)
+#         if pool:
+#             assert stride == 4
+#         layers = layers or LayerFn()
+#
+#         if isinstance(out_chs, (list, tuple)):
+#             num_rep = len(out_chs)
+#             stem_chs = out_chs
+#         else:
+#             stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
+#
+#         self.stride = stride
+#         stem_strides = [2] + [1] * (num_rep - 1)
+#         if stride == 4 and not pool:
+#             # set last conv in stack to be strided if stride == 4 and no pooling layer
+#             stem_strides[-1] = 2
+#
+#         num_act = num_rep if num_act is None else num_act
+#         # if num_act < num_rep, first convs in stack won't have bn + act
+#         stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
+#         prev_chs = in_chs
+#         convs = []
+#         for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
+#             layer_fn = layers.conv_norm_act if na else create_conv2d
+#             convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
+#             prev_chs = ch
+#         self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0]
+#
+#         if not pool:
+#             self.pool = nn.Identity()
+#         elif 'max' in pool.lower():
+#             self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity()
+#         else:
+#             assert False, "Unknown pooling type"
+#
+#     def forward(self, x):
+#         x = self.conv(x)
+#         x = self.pool(x)
+#         return x
+
+
+class Stem(nn.Sequential):
+
+    def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
+                 num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
+        super().__init__()
+        assert stride in (2, 4)
+        layers = layers or LayerFn()
+
+        if isinstance(out_chs, (list, tuple)):
+            num_rep = len(out_chs)
+            stem_chs = out_chs
+        else:
+            stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
+
+        self.stride = stride
+        self.feature_info = []  # track intermediate features
+        prev_feat = ''
+        stem_strides = [2] + [1] * (num_rep - 1)
+        if stride == 4 and not pool:
+            # set last conv in stack to be strided if stride == 4 and no pooling layer
+            stem_strides[-1] = 2
+
+        num_act = num_rep if num_act is None else num_act
+        # if num_act < num_rep, first convs in stack won't have bn + act
+        stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
+        prev_chs = in_chs
+        curr_stride = 1
+        for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
+            layer_fn = layers.conv_norm_act if na else create_conv2d
+            conv_name = f'conv{i + 1}'
+            if i > 0 and s > 1:
+                self.feature_info.append(dict(num_chs=ch, reduction=curr_stride, module=prev_feat))
+            self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
+            prev_chs = ch
+            curr_stride *= s
+            prev_feat = conv_name
+
+        if 'max' in pool.lower():
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
+            curr_stride *= 2
+            prev_feat = 'pool'
+
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        assert curr_stride == stride
+
+
+def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
+    layers = layers or LayerFn()
+    assert stem_type in ('', 'quad', 'tiered', 'deep', 'rep', '7x7', '3x3')
+    if 'quad' in stem_type:
+        # based on NFNet stem, stack of 4 3x3 convs
+        num_act = 2 if 'quad2' in stem_type else None
+        stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
+    elif 'tiered' in stem_type:
+        # 3x3 stack of 3 convs as in my ResNet-T
+        stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
+    elif 'deep' in stem_type:
+        # 3x3 stack of 3 convs as in ResNet-D
+        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
+    elif 'rep' in stem_type:
+        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
+    elif '7x7' in stem_type:
+        # 7x7 stem conv as in ResNet
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
+    else:
+        # 3x3 stem conv as in RegNet is the default
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
+
+    if isinstance(stem, Stem):
+        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
+    else:
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+    return stem, feature_info
+
+
+def reduce_feat_size(feat_size, stride=2):
+    return None if feat_size is None else tuple([s // stride for s in feat_size])
+
+
+def create_byob_stages(
+        cfg, drop_path_rate, output_stride, stem_feat,
+        feat_size=None, layers=None, extra_args_fn=None):
+    layers = layers or LayerFn()
+    feature_info = []
+    block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
+    depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
+    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+    dilation = 1
+    net_stride = stem_feat['reduction']
+    prev_chs = stem_feat['num_chs']
+    prev_feat = stem_feat
+    stages = []
+    for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
+        stride = stage_block_cfgs[0].s
+        if stride != 1 and prev_feat:
+            feature_info.append(prev_feat)
+        if net_stride >= output_stride and stride > 1:
+            dilation *= stride
+            stride = 1
+        net_stride *= stride
+        first_dilation = 1 if dilation in (1, 2) else 2
+
+        blocks = []
+        for block_idx, block_cfg in enumerate(stage_block_cfgs):
+            out_chs = make_divisible(block_cfg.c * cfg.width_factor)
+            group_size = block_cfg.gs
+            if isinstance(group_size, Callable):
+                group_size = group_size(out_chs, block_idx)
+            block_kwargs = dict(  # Blocks used in this model must accept these arguments
+                in_chs=prev_chs,
+                out_chs=out_chs,
+                stride=stride if block_idx == 0 else 1,
+                dilation=(first_dilation, dilation),
+                group_size=group_size,
+                bottle_ratio=block_cfg.br,
+                downsample=cfg.downsample,
+                drop_path_rate=dpr[stage_idx][block_idx],
+                layers=layers,
+            )
+            if extra_args_fn is not None:
+                extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size)
+            blocks += [create_block(block_cfg.type, **block_kwargs)]
+            first_dilation = dilation
+            prev_chs = out_chs
+            if stride > 1 and block_idx == 0:
+                feat_size = reduce_feat_size(feat_size, stride)
+
+        stages += [nn.Sequential(*blocks)]
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+
+    feature_info.append(prev_feat)
+    return nn.Sequential(*stages), feature_info
+
+
+def get_layer_fns(cfg: ByobCfg):
+    act = get_act_layer(cfg.act_layer)
+    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
+    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
+    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
+    layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn)
+    return layer_fn


 class ByobNet(nn.Module):
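get_layer_fns is the one place where ByobCfg's string-valued fields (act_layer, norm_layer, attn_layer) are resolved into concrete layer callables, and the stem/stage builders above consume the resulting LayerFn. A rough construction sketch using the resnet52q cfg added in this commit (illustrative; the normal path is timm.create_model / _create_byobnet):

    model = ByobNet(model_cfgs['resnet52q'], num_classes=1000)   # cfg resolved via get_layer_fns internally
    out = model(torch.randn(1, 3, 224, 224))                     # expected shape: torch.Size([1, 1000])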
@@ -546,79 +724,30 @@ class ByobNet(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        norm_layer = cfg.norm_layer
-        act_layer = get_act_layer(cfg.act_layer)
-        attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
-        layer_cfg = dict(norm_layer=norm_layer, act_layer=act_layer, attn_layer=attn_layer)
-
-        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
-        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, layer_cfg=layer_cfg)
+        layers = get_layer_fns(cfg)

         self.feature_info = []
-        depths = [bc.d for bc in cfg.blocks]
-        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        prev_name = 'stem'
-        prev_chs = stem_chs
-        net_stride = 2
-        dilation = 1
-        stages = []
-        for stage_idx, block_cfg in enumerate(cfg.blocks):
-            stride = block_cfg.s
-            if stride != 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=prev_name))
-            if net_stride >= output_stride and stride > 1:
-                dilation *= stride
-                stride = 1
-            net_stride *= stride
-            first_dilation = 1 if dilation in (1, 2) else 2
-
-            blocks = []
-            for block_idx in range(block_cfg.d):
-                out_chs = make_divisible(block_cfg.c * cfg.width_factor)
-                group_size = block_cfg.gs
-                if isinstance(group_size, Callable):
-                    group_size = group_size(out_chs, block_idx)
-                block_kwargs = dict(  # Blocks used in this model must accept these arguments
-                    in_chs=prev_chs,
-                    out_chs=out_chs,
-                    stride=stride if block_idx == 0 else 1,
-                    dilation=(first_dilation, dilation),
-                    group_size=group_size,
-                    bottle_ratio=block_cfg.br,
-                    downsample=cfg.downsample,
-                    drop_path_rate=dpr[stage_idx][block_idx],
-                    layer_cfg=layer_cfg,
-                )
-                blocks += [create_block(block_cfg.type, **block_kwargs)]
-                first_dilation = dilation
-                prev_chs = out_chs
-            stages += [nn.Sequential(*blocks)]
-            prev_name = f'stages.{stage_idx}'
-        self.stages = nn.Sequential(*stages)
-
+        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
+        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
+        self.feature_info.extend(stem_feat[:-1])
+
+        self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers)
+        self.feature_info.extend(stage_feat[:-1])
+
+        prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:
             self.num_features = int(round(cfg.width_factor * cfg.num_features))
-            self.final_conv = ConvBnAct(prev_chs, self.num_features, 1, **_na_args(layer_cfg))
+            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
-        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_conv')]
+        self.feature_info += [
+            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]

         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

         for n, m in self.named_modules():
-            if isinstance(m, nn.Conv2d):
-                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                fan_out //= m.groups
-                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, mean=0.0, std=0.01)
-                nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.ones_(m.weight)
-                nn.init.zeros_(m.bias)
+            _init_weights(m, n)
         for m in self.modules():
             # call each block's weight init for block-specific overrides to init above
             if hasattr(m, 'init_weights'):
@@ -642,6 +771,22 @@ class ByobNet(nn.Module):
         return x


+def _init_weights(m, n=''):
+    if isinstance(m, nn.Conv2d):
+        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        fan_out //= m.groups
+        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif isinstance(m, nn.Linear):
+        nn.init.normal_(m.weight, mean=0.0, std=0.01)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.BatchNorm2d):
+        nn.init.ones_(m.weight)
+        nn.init.zeros_(m.bias)
+
+
 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
timm/models/layers/__init__.py

@@ -13,6 +13,7 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
 from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
+from .create_self_attn import get_self_attn, create_self_attn
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
@@ -20,6 +21,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
+from .norm import GroupNorm
 from .norm_act import BatchNormAct2d, GroupNormAct
 from .padding import get_padding, get_same_padding, pad_same
 from .pool2d_same import AvgPool2dSame, create_pool2d
120
timm/models/layers/bottleneck_attn.py
Normal file
@ -0,0 +1,120 @@
""" Bottleneck Self Attention (Bottleneck Transformers)

Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605

@misc{2101.11605,
    Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
    Title = {Bottleneck Transformers for Visual Recognition},
    Year = {2021},
}

Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2

This impl is a WIP, but since it is based on the ref gist it is likely not too far off.

Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from .helpers import to_2tuple


def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """ Compute relative logits along one dimension

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Args:
        q: (batch, heads, height, width, dim)
        rel_k: (2 * width - 1, dim)
        permute_mask: permute output dim according to this
    """
    B, H, W, dim = q.shape
    x = (q @ rel_k.transpose(-1, -2))
    x = x.reshape(-1, W, 2 * W - 1)

    # pad to shift from relative to absolute indexing
    x_pad = F.pad(x, [0, 1]).flatten(1)
    x_pad = F.pad(x_pad, [0, W - 1])

    # reshape and slice out the padded elements
    x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)
    x = x_pad[:, :W, W - 1:]

    # reshape and tile
    x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)
    return x.permute(permute_mask)


class PosEmbedRel(nn.Module):
    """ Relative Position Embedding
    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
    """
    def __init__(self, feat_size, dim_head, scale):
        super().__init__()
        self.height, self.width = to_2tuple(feat_size)
        self.dim_head = dim_head
        self.scale = scale
        self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
        self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)

    def forward(self, q):
        B, num_heads, HW, _ = q.shape

        # relative logits in width dimension.
        q = q.reshape(B * num_heads, self.height, self.width, -1)
        rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))

        # relative logits in height dimension.
        q = q.transpose(1, 2)
        rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))

        rel_logits = rel_logits_h + rel_logits_w
        rel_logits = rel_logits.reshape(B, num_heads, HW, HW)
        return rel_logits


class BottleneckAttn(nn.Module):
    """ Bottleneck Attention
    Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
    """
    def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False):
        super().__init__()
        assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.dim_head = dim_out // num_heads
        self.scale = self.dim_head ** -0.5

        self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias)

        # NOTE I'm only supporting relative pos embedding for now
        self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale)

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.pos_embed.height and W == self.pos_embed.width

        x = self.qkv(x)  # B, 3 * num_heads * dim_head, H, W
        x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
        q, k, v = torch.split(x, self.num_heads, dim=1)

        attn_logits = (q @ k.transpose(-1, -2)) * self.scale
        attn_logits = attn_logits + self.pos_embed(q)  # B, num_heads, H * W, H * W

        attn_out = attn_logits.softmax(dim=-1)
        attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W)  # B, dim_out, H, W
        attn_out = self.pool(attn_out)
        return attn_out
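A minimal usage sketch of `BottleneckAttn` as defined above (not part of the commit; sizes are arbitrary). The key constraints are that `feat_size` must match the spatial size of the inputs exactly, since the relative position tables are sized from it at construction, and that `dim_out` must be divisible by `num_heads`:

# illustrative sketch only, not part of this commit
import torch
from timm.models.layers.bottleneck_attn import BottleneckAttn

attn = BottleneckAttn(dim=256, dim_out=256, feat_size=(14, 14), num_heads=4)
x = torch.randn(2, 256, 14, 14)   # H, W must equal feat_size
print(attn(x).shape)              # torch.Size([2, 256, 14, 14])

# stride=2 is handled by a 2x2 average pool on the attention output
attn_s2 = BottleneckAttn(dim=256, feat_size=(14, 14), stride=2)
print(attn_s2(x).shape)           # torch.Size([2, 256, 7, 7])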
17
timm/models/layers/create_self_attn.py
Normal file
@ -0,0 +1,17 @@
from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer


def get_self_attn(attn_type):
    if attn_type == 'bottleneck':
        return BottleneckAttn
    elif attn_type == 'halo':
        return HaloAttn
    elif attn_type == 'lambda':
        return LambdaLayer


def create_self_attn(attn_type, dim, stride=1, **kwargs):
    attn_fn = get_self_attn(attn_type)
    return attn_fn(dim, stride=stride, **kwargs)
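With the new exports in timm/models/layers/__init__.py the factory can be imported directly; keyword args are forwarded to the chosen layer, so each type's own requirements (e.g. `feat_size` for 'bottleneck') still apply. A minimal sketch (not part of the commit; sizes are arbitrary):

# illustrative sketch only, not part of this commit
import torch
from timm.models.layers import create_self_attn

bot = create_self_attn('bottleneck', dim=256, feat_size=(16, 16))    # -> BottleneckAttn
halo = create_self_attn('halo', dim=256, block_size=8, halo_size=3)  # -> HaloAttn
lam = create_self_attn('lambda', dim=256)                            # -> LambdaLayer

x = torch.randn(2, 256, 16, 16)
print(bot(x).shape, halo(x).shape, lam(x).shape)  # each torch.Size([2, 256, 16, 16])

Note that `get_self_attn` has no fallback branch, so an unrecognized `attn_type` silently returns None.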
157
timm/models/layers/halo_attn.py
Normal file
@ -0,0 +1,157 @@
""" Halo Self Attention

Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
    - https://arxiv.org/abs/2103.12731

@misc{2103.12731,
    Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
    Jonathon Shlens},
    Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
    Year = {2021},
}

Status:
This impl is a WIP; there is no official ref impl and some details in the paper weren't clear to me.

Trying to match the 'H1' variant in the paper, my parameter counts are 2M less and the model
is extremely slow. Something isn't right. However, the models do appear to train and experimental
variants with attn in C4 and/or C5 stages run at tolerable speed.

Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Tuple, List

import torch
from torch import nn
import torch.nn.functional as F


def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """ Compute relative logits along one dimension

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Args:
        q: (batch, height, width, dim)
        rel_k: (2 * window - 1, dim)
        permute_mask: permute output dim according to this
    """
    B, H, W, dim = q.shape
    rel_size = rel_k.shape[0]
    win_size = (rel_size + 1) // 2

    x = (q @ rel_k.transpose(-1, -2))
    x = x.reshape(-1, W, rel_size)

    # pad to shift from relative to absolute indexing
    x_pad = F.pad(x, [0, 1]).flatten(1)
    x_pad = F.pad(x_pad, [0, rel_size - W])

    # reshape and slice out the padded elements
    x_pad = x_pad.reshape(-1, W + 1, rel_size)
    x = x_pad[:, :W, win_size - 1:]

    # reshape and tile
    x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)
    return x.permute(permute_mask)


class PosEmbedRel(nn.Module):
    """ Relative Position Embedding
    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
    """
    def __init__(self, block_size, win_size, dim_head, scale):
        """
        Args:
            block_size (int): block size
            win_size (int): neighbourhood window size
            dim_head (int): attention head dim
            scale (float): scale factor (for init)
        """
        super().__init__()
        self.block_size = block_size
        self.dim_head = dim_head
        self.scale = scale
        self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
        self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)

    def forward(self, q):
        B, BB, HW, _ = q.shape

        # relative logits in width dimension.
        q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)
        rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))

        # relative logits in height dimension.
        q = q.transpose(1, 2)
        rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))

        rel_logits = rel_logits_h + rel_logits_w
        rel_logits = rel_logits.reshape(B, BB, HW, -1)
        return rel_logits


class HaloAttn(nn.Module):
    """ Halo Attention

    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
        - https://arxiv.org/abs/2103.12731
    """
    def __init__(
            self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0
        self.stride = stride
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.dim_qk = num_heads * dim_head
        self.dim_v = dim_out
        self.block_size = block_size
        self.halo_size = halo_size
        self.win_size = block_size + halo_size * 2  # neighbourhood window size
        self.scale = self.dim_head ** -0.5

        # FIXME not clear if this stride behaviour is what the paper intended
        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
        # data in unfolded block form. I haven't wrapped my head around how that'd look.
        self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
        self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)

        self.pos_embed = PosEmbedRel(
            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H % self.block_size == 0 and W % self.block_size == 0
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

        q = self.q(x)
        q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
        # B, num_heads * dim_head * block_size ** 2, num_blocks
        q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
        # B * num_heads, num_blocks, block_size ** 2, dim_head

        kv = self.kv(x)
        # FIXME I 'think' this unfold does what I want it to, but I should investigate
        k = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
        k = k.reshape(
            B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
        k, v = torch.split(k, [self.dim_head, self.dim_v // self.num_heads], dim=-1)

        attn_logits = (q @ k.transpose(-1, -2)) * self.scale  # FIXME should usual attn scale be applied?
        attn_logits = attn_logits + self.pos_embed(q)  # B * num_heads, block_size ** 2, win_size ** 2

        attn_out = attn_logits.softmax(dim=-1)
        attn_out = (attn_out @ v).transpose(1, 3)  # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
        attn_out = F.fold(
            attn_out.reshape(B, -1, num_blocks),
            (H // self.stride, W // self.stride),
            kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride)
        # B, dim_out, H // stride, W // stride
        return attn_out
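A minimal usage sketch of `HaloAttn` (not part of the commit; sizes are arbitrary). With `block_size=8` and `halo_size=3`, each 8x8 block of queries attends over a 14x14 key/value window gathered via `F.unfold`; the input H and W must be multiples of `block_size` and `dim_out` must be divisible by `num_heads`:

# illustrative sketch only, not part of this commit
import torch
from timm.models.layers.halo_attn import HaloAttn

attn = HaloAttn(dim=256, dim_out=256, num_heads=8, dim_head=16, block_size=8, halo_size=3)
x = torch.randn(2, 256, 32, 32)    # 32 is a multiple of block_size=8
print(attn(x).shape)               # torch.Size([2, 256, 32, 32])

# stride=2 uses a strided 1x1 query conv and folds to the smaller output grid
attn_s2 = HaloAttn(dim=256, stride=2)
print(attn_s2(x).shape)            # torch.Size([2, 256, 16, 16])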
78
timm/models/layers/lambda_layer.py
Normal file
@ -0,0 +1,78 @@
""" Lambda Layer

Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
    - https://arxiv.org/abs/2102.08602

@misc{2102.08602,
    Author = {Irwan Bello},
    Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
    Year = {2021},
}

Status:
This impl is a WIP. Code snippets in the paper were used as reference, but there is a
good chance some details are missing/wrong.

I've only implemented local lambda conv based pos embeddings.

For a PyTorch impl that includes other embedding options check out
https://github.com/lucidrains/lambda-networks

Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from torch import nn
import torch.nn.functional as F


class LambdaLayer(nn.Module):
    """Lambda Layer w/ lambda conv position embedding

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602
    """
    def __init__(
            self,
            dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=5, qkv_bias=False):
        super().__init__()
        self.dim_out = dim_out or dim
        self.dim_k = dim_head  # query depth 'k'
        self.num_heads = num_heads
        assert self.dim_out % num_heads == 0, 'dim_out should be divisible by num_heads'
        self.dim_v = self.dim_out // num_heads  # value depth 'v'
        self.r = r  # relative position neighbourhood (lambda conv kernel size)

        self.qkv = nn.Conv2d(
            dim,
            num_heads * dim_head + dim_head + self.dim_v,
            kernel_size=1, bias=qkv_bias)
        self.norm_q = nn.BatchNorm2d(num_heads * dim_head)
        self.norm_v = nn.BatchNorm2d(self.dim_v)

        # NOTE currently only supporting the local lambda convolutions for position embedding
        self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        M = H * W

        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, [
            self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1)
        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2)  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1)  # B, K, M

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V

        position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
        position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V

        out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W)  # B, C (num_heads * V), H, W
        out = self.pool(out)
        return out
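A minimal usage sketch of `LambdaLayer` (not part of the commit; sizes are arbitrary). It is exercised here with `dim_out` equal to `dim`, since the final reshape in this WIP version uses the input channel count; the content lambda attends globally via softmaxed keys, while the position lambdas come from the local r-by-r `conv_lambda`:

# illustrative sketch only, not part of this commit
import torch
from timm.models.layers.lambda_layer import LambdaLayer

lam = LambdaLayer(dim=256, dim_out=256, num_heads=4, dim_head=16, r=5)
x = torch.randn(2, 256, 14, 14)
print(lam(x).shape)     # torch.Size([2, 256, 14, 14])

# stride=2 downsamples the output with a 2x2 average pool
lam_s2 = LambdaLayer(dim=256, stride=2)
print(lam_s2(x).shape)  # torch.Size([2, 256, 7, 7])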