Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 7fd3674d0d (parent 3055411c1b): Add mobileone and update repvgg
@@ -91,6 +91,27 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
    return bcfg


def _mobileone_bcfg(d=(2, 8, 10, 1), wf=(1., 1., 1., 1.), se_blocks=(), num_conv_branches=1):
    c = (64, 128, 256, 512)
    prev_c = min(64, c[0] * wf[0])
    se_blocks = se_blocks or (0,) * len(d)
    bcfg = []
    for d, c, w, se in zip(d, c, wf, se_blocks):
        scfg = []
        for i in range(d):
            out_c = c * w
            bk = dict(num_conv_branches=num_conv_branches)
            ak = {}
            if i >= d - se:
                ak['attn_layer'] = 'se'
            scfg += [ByoBlockCfg(type='one', d=1, c=prev_c, gs=1, block_kwargs=bk, **ak)]  # depthwise block
            scfg += [ByoBlockCfg(
                type='one', d=1, c=out_c, gs=0, block_kwargs=dict(kernel_size=1, **bk), **ak)]  # pointwise block
            prev_c = out_c
        bcfg += [scfg]
    return bcfg


def interleave_blocks(
        types: Tuple[str, str], d,
        every: Union[int, List[int]] = 1,
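
A minimal sketch of what this helper produces (the import path timm.models.byobnet is an assumption, and _mobileone_bcfg is module-private): each logical block expands into a depthwise/pointwise ByoBlockCfg pair, so a stage of depth d yields two configs per block.

# Inspect the stage configs generated for the MobileOne-S0 layout.
from timm.models.byobnet import _mobileone_bcfg

bcfg = _mobileone_bcfg(wf=(0.75, 1.0, 1.0, 2.), num_conv_branches=4)
for i, stage in enumerate(bcfg):
    # Depths d=(2, 8, 10, 1) expand to 4, 16, 20, 2 block cfgs per stage.
    print(f'stage {i}: {len(stage)} block cfgs, out_chs={stage[-1].c}')
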
@@ -468,7 +489,7 @@ class RepVggBlock(nn.Module):
        super(RepVggBlock, self).__init__()
        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)

        # self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)  # FIXME temp for remapping
        use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
        self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
        self.conv_kxk = layers.conv_norm_act(
@@ -501,6 +522,215 @@ class RepVggBlock(nn.Module):
        return self.act(x)


class MobileOneBlock(nn.Module):
    """ MobileOne building block.

    This block has a multi-branched architecture at train-time
    and a plain-CNN style architecture at inference time.
    For more details, please refer to the paper:
    `MobileOne: An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,  # unused
            group_size: Optional[int] = None,
            downsample: str = '',  # unused
            inference_mode: bool = False,
            num_conv_branches: int = 1,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ) -> None:
        """ Construct a MobileOneBlock module.

        :param in_chs: Number of channels in the input.
        :param out_chs: Number of channels produced by the block.
        :param kernel_size: Size of the convolution kernel.
        :param stride: Stride size.
        :param dilation: Kernel dilation factor.
        :param group_size: Channels per group for grouped convolution (see num_groups).
        :param inference_mode: If True, instantiate the block in re-parameterized (inference) mode.
        :param num_conv_branches: Number of linear conv branches.
        """
        super(MobileOneBlock, self).__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.in_channels = in_chs
        self.out_channels = out_chs
        self.num_conv_branches = num_conv_branches
        layers = layers or LayerFn()
        self.groups = groups = num_groups(group_size, in_chs)  # _fuse_bn_tensor relies on self.groups

        # Check if SE-ReLU is requested
        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)  # FIXME move after remap

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=kernel_size,
                stride=stride,
                padding=kernel_size // 2,  # same-pad, matching the fused conv built in reparameterize()
                dilation=dilation,
                groups=groups,
                bias=True)
        else:
            self.reparam_conv = None

            # Re-parameterizable skip connection
            use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
            self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None

            # Re-parameterizable conv branches
            convs = list()
            for _ in range(self.num_conv_branches):
                convs.append(layers.conv_norm_act(
                    in_chs, out_chs, kernel_size=kernel_size,
                    stride=stride, groups=groups, apply_act=False))
            self.conv_kxk = nn.ModuleList(convs)

            # Re-parameterizable scale branch
            self.conv_scale = None
            if kernel_size > 1:
                self.conv_scale = layers.conv_norm_act(
                    in_chs, out_chs, kernel_size=1,
                    stride=stride, groups=groups, apply_act=False)

        self.act = layers.act(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply forward pass. """
        # Inference mode forward pass.
        if self.reparam_conv is not None:
            return self.act(self.attn(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.identity is not None:
            identity_out = self.identity(x)

        # Scale branch output
        scale_out = 0
        if self.conv_scale is not None:
            scale_out = self.conv_scale(x)

        # Other branches
        out = scale_out + identity_out
        for ix in range(self.num_conv_branches):
            out += self.conv_kxk[ix](x)

        return self.act(self.attn(out))

    def reparameterize(self):
        """ Following the approach of `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf, we re-parameterize the multi-branched
        architecture used at training time into a plain CNN-like structure
        for inference.
        """
        if self.reparam_conv is not None:
            return

        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.conv_kxk[0].conv.in_channels,
            out_channels=self.conv_kxk[0].conv.out_channels,
            kernel_size=self.conv_kxk[0].conv.kernel_size,
            stride=self.conv_kxk[0].conv.stride,
            padding=self.conv_kxk[0].conv.padding,
            dilation=self.conv_kxk[0].conv.dilation,
            groups=self.conv_kxk[0].conv.groups,
            bias=True)
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete unused branches
        for para in self.parameters():
            para.detach_()
        self.__delattr__('conv_kxk')
        self.__delattr__('conv_scale')
        if hasattr(self, 'identity'):
            self.__delattr__('identity')

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        :return: Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.conv_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.identity is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to fuse batchnorm layer with preceding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        :param branch: Branch to fuse, either a ConvNormAct or a bare norm layer (nn.BatchNorm2d).
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, ConvNormAct):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                # Express the identity mapping as a conv kernel (a 1 at the center of
                # each channel's own filter) so the BN-only skip branch can be fused
                # the same way as the conv branches.
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
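
The fusion rule used above is standard conv+BN folding: w' = w * gamma / std and b' = beta - mean * gamma / std. A standalone sketch verifying it (plain PyTorch, not part of this module; layer sizes are arbitrary):

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()  # eval mode: BN uses running stats
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 2.0)
bn.bias.data.uniform_(-1, 1)

std = (bn.running_var + bn.eps).sqrt()
fused = nn.Conv2d(8, 8, 3, padding=1, bias=True)
fused.weight.data = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
fused.bias.data = bn.bias - bn.running_mean * bn.weight / std

x = torch.randn(2, 8, 16, 16)
torch.testing.assert_close(bn(conv(x)), fused(x), rtol=1e-4, atol=1e-5)
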

class SelfAttnBlock(nn.Module):
    """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
    """
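
A quick equivalence check for MobileOneBlock (a sketch; the import path is an assumption at this commit): in eval mode, the single conv produced by reparameterize() should match the multi-branch forward to numerical precision.

import torch
from timm.models.byobnet import MobileOneBlock

block = MobileOneBlock(32, 32, kernel_size=3, stride=1, num_conv_branches=4).eval()
x = torch.randn(1, 32, 56, 56)
with torch.no_grad():
    y_branched = block(x)   # multi-branch (train-time) structure
    block.reparameterize()  # fold all branches into a single conv
    y_fused = block(x)
torch.testing.assert_close(y_branched, y_fused, rtol=1e-4, atol=1e-5)
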
@@ -576,6 +806,7 @@ _block_registry = dict(
    dark=DarkBlock,
    edge=EdgeBlock,
    rep=RepVggBlock,
    one=MobileOneBlock,
    self_attn=SelfAttnBlock,
)

@@ -657,7 +888,7 @@ def create_byob_stem(
        layers: LayerFn = None,
):
    layers = layers or LayerFn()
-   assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
+   assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', 'one', '7x7', '3x3')
    if 'quad' in stem_type:
        # based on NFNet stem, stack of 4 3x3 convs
        num_act = 2 if 'quad2' in stem_type else None
@@ -670,6 +901,8 @@ def create_byob_stem(
        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
    elif 'rep' in stem_type:
        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
    elif 'one' in stem_type:
        stem = MobileOneBlock(in_chs, out_chs, kernel_size=3, stride=2, layers=layers)
    elif '7x7' in stem_type:
        # 7x7 stem conv as in ResNet
        if pool_type:
@@ -1033,6 +1266,13 @@ model_cfgs = dict(
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_d2se=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(8, 14, 24, 1), wf=(2.5, 2.5, 2.5, 5.)),
        stem_type='rep',
        stem_chs=64,
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.0625, rd_divisor=1),
    ),

    # 4 x conv stem w/ 2 act, no maxpool, 2,4,6,4 repeats, group size 32 in first 3 blocks
    # DW convs in last block, 2048 pre-FC, silu act
@@ -1375,6 +1615,32 @@ model_cfgs = dict(
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),

    mobileone_s0=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(0.75, 1.0, 1.0, 2.), num_conv_branches=4),
        stem_type='one',
        stem_chs=48,
    ),
    mobileone_s1=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(1.5, 1.5, 2.0, 2.5)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s2=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(1.5, 2.0, 2.5, 4.0)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s3=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(2.0, 2.5, 3.0, 4.0)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s4=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(3.0, 3.5, 3.5, 4.0), se_blocks=(0, 0, 5, 1)),
        stem_type='one',
        stem_chs=64,
    ),
)

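
As a quick arithmetic check on these configs (plain Python, nothing timm-specific): the per-stage output widths are the base channels (64, 128, 256, 512) scaled by each variant's wf.

base = (64, 128, 256, 512)
for name, wf in [('mobileone_s0', (0.75, 1.0, 1.0, 2.0)),
                 ('mobileone_s4', (3.0, 3.5, 3.5, 4.0))]:
    print(name, [int(c * w) for c, w in zip(base, wf)])
# mobileone_s0 [48, 128, 256, 1024]
# mobileone_s4 [192, 448, 896, 2048]
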
@@ -1630,6 +1896,14 @@ def repvgg_b3g4(pretrained=False, **kwargs) -> ByobNet:
    return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)


@register_model
def repvgg_d2se(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-D2se
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_d2se', pretrained=pretrained, **kwargs)


@register_model
def resnet51q(pretrained=False, **kwargs) -> ByobNet:
    """
@@ -1782,3 +2056,38 @@ def regnetz_d8_evos(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s0(pretrained=False, **kwargs) -> ByobNet:
    """ MobileOne-S0
    """
    return _create_byobnet('mobileone_s0', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s1(pretrained=False, **kwargs) -> ByobNet:
    """ MobileOne-S1
    """
    return _create_byobnet('mobileone_s1', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s2(pretrained=False, **kwargs) -> ByobNet:
    """ MobileOne-S2
    """
    return _create_byobnet('mobileone_s2', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s3(pretrained=False, **kwargs) -> ByobNet:
    """ MobileOne-S3
    """
    return _create_byobnet('mobileone_s3', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
    """ MobileOne-S4
    """
    return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs)
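
End-to-end usage sketch (the model names come from the registrations above; walking modules to call reparameterize() is an assumption here, not a helper timm provides at this commit):

import timm
import torch

model = timm.create_model('mobileone_s0', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y_ref = model(x)
    # Fold every re-parameterizable block into its single fused conv.
    for m in list(model.modules()):  # snapshot the list, since blocks mutate themselves
        if hasattr(m, 'reparameterize'):
            m.reparameterize()
    y_fused = model(x)
torch.testing.assert_close(y_ref, y_fused, rtol=1e-4, atol=1e-4)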