mirror of https://github.com/open-mmlab/mmyolo.git
add yolo_bricks (#21)
parent c2c4d37f1e
commit 7e90343d85
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ema import ExpMomentumEMA
from .yolo_bricks import RepStageBlock, RepVGGBlock, SPPFBottleneck

__all__ = ['SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA']

@@ -0,0 +1,297 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmengine.utils import digit_version

from mmyolo.registry import MODELS

||||
if digit_version(torch.__version__) >= digit_version('1.7.0'):
|
||||
MODELS.register_module(module=nn.SiLU, name='SiLU')
|
||||
else:
|
||||
|
||||
class SiLU(nn.Module):
|
||||
|
||||
def __init__(self, inplace=True):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, inputs) -> torch.Tensor:
|
||||
return inputs * torch.sigmoid(inputs)
|
||||
|
||||
MODELS.register_module(module=SiLU, name='SiLU')
|
||||
|
||||
|
||||
class SPPFBottleneck(BaseModule):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 and YOLOX,
    by Glenn Jocher.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        kernel_sizes (int, tuple[int]): Kernel size(s) of the pooling
            layers. If an int is given, a single pooling layer is applied
            three times in sequence; if a tuple is given, one pooling
            layer is built per kernel size. Defaults to 5.
        conv_cfg (dict, optional): Config dict for the convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for the activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (dict or list[dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_sizes: Union[int, Tuple[int]] = 5,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        mid_channels = in_channels // 2
        self.conv1 = ConvModule(
            in_channels,
            mid_channels,
            1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.kernel_sizes = kernel_sizes
        if isinstance(kernel_sizes, int):
            self.poolings = nn.MaxPool2d(
                kernel_size=kernel_sizes, stride=1, padding=kernel_sizes // 2)
        else:
            self.poolings = nn.ModuleList([
                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
                for ks in kernel_sizes
            ])

        # conv2 fuses the input map plus one map per pooling: three
        # sequential poolings in the int case, len(kernel_sizes) otherwise.
        conv2_in_channels = mid_channels * 4 if isinstance(
            kernel_sizes, int) else mid_channels * (len(kernel_sizes) + 1)
        self.conv2 = ConvModule(
            conv2_in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process."""
        x = self.conv1(x)
        if isinstance(self.kernel_sizes, int):
            y1 = self.poolings(x)
            y2 = self.poolings(y1)
            x = torch.cat([x, y1, y2, self.poolings(y2)], dim=1)
        else:
            x = torch.cat(
                [x] + [pooling(x) for pooling in self.poolings], dim=1)
        x = self.conv2(x)
        return x
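
The int and tuple forms are interchangeable because three chained 5x5 max-pools cover the same receptive fields as parallel 5x5/9x9/13x13 pools. A minimal sanity-check sketch, assuming the class above is in scope; the tensor shapes are arbitrary:

import torch

# SPPF style: one 5x5 pool applied three times in sequence.
sppf = SPPFBottleneck(in_channels=256, out_channels=256, kernel_sizes=5)
# SPP style: three parallel pools of increasing kernel size.
spp = SPPFBottleneck(in_channels=256, out_channels=256, kernel_sizes=(5, 9, 13))

x = torch.rand(1, 256, 20, 20)
# Both variants concatenate four mid_channels maps, so output shapes match.
assert sppf(x).shape == spp(x).shape == (1, 256, 20, 20)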


class RepVGGBlock(nn.Module):
    """A basic rep-style block with separate training and deploy modes.
    This code is based on
    https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
            Default: 3.
        stride (int or tuple): Stride of the convolution. Default: 1.
        padding (int, tuple): Padding added to all four sides of
            the input. Default: 1.
        dilation (int or tuple): Spacing between kernel elements.
            Default: 1.
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1.
        padding_mode (str, optional): Padding mode of the reparameterized
            convolution. Default: 'zeros'.
        deploy (bool): Whether in deploy mode. Default: False.
        use_se (bool): Whether to use an SE block. Default: False.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]] = 3,
                 stride: Union[int, Tuple[int]] = 1,
                 padding: Union[int, Tuple[int]] = 1,
                 dilation: Union[int, Tuple[int]] = 1,
                 groups: Optional[int] = 1,
                 padding_mode: Optional[str] = 'zeros',
                 deploy: bool = False,
                 use_se: bool = False):
        super().__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if use_se:
            raise NotImplementedError('se block not supported yet')
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(
                num_features=in_channels, momentum=0.03, eps=0.001
            ) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = ConvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=False,
                norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                act_cfg=None)
            self.rbr_1x1 = ConvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                padding=padding_11,
                groups=groups,
                bias=False,
                norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                act_cfg=None)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward process."""
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(
            self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    def get_equivalent_kernel_bias(self):
        """Derives the equivalent kernel and bias in a differentiable way."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Pad a 1x1 kernel tensor to 3x3 by zero-padding its borders."""
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(
            self, branch: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """Derives the equivalent kernel and bias of a specific branch layer.

        Args:
            branch (nn.Module): The layer that needs to be equivalently
                transformed, which can be a ConvModule or nn.BatchNorm2d.

        Returns:
            tuple: Equivalent kernel and bias.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvModule):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                # The identity branch has no conv; build an equivalent
                # 3x3 kernel that passes each channel through unchanged.
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # Fold BN into the conv: w' = w * gamma / std, b' = beta - mean * gamma / std.
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        """Switch to deploy mode."""
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.rbr_dense.conv.in_channels,
            out_channels=self.rbr_dense.conv.out_channels,
            kernel_size=self.rbr_dense.conv.kernel_size,
            stride=self.rbr_dense.conv.stride,
            padding=self.rbr_dense.conv.padding,
            dilation=self.rbr_dense.conv.dilation,
            groups=self.rbr_dense.conv.groups,
            bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True
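
The reparameterization is only exact when BatchNorm uses its running statistics, so equivalence should be checked in eval mode. A minimal sanity-check sketch, assuming RepVGGBlock as defined above; sizes are arbitrary:

import torch

block = RepVGGBlock(64, 64).eval()  # eval(): BN folds via running stats
x = torch.rand(1, 64, 32, 32)
y_multi_branch = block(x)
block.switch_to_deploy()  # collapses 3x3 + 1x1 + identity into one conv
y_single_branch = block(x)
assert torch.allclose(y_multi_branch, y_single_branch, atol=1e-5)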


class RepStageBlock(nn.Module):
    """RepStageBlock is a stage block built from rep-style basic blocks.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n (int): Number of blocks in the stage. Defaults to 1.
        block (Type[nn.Module]): Basic block class used to build the stage.
            Defaults to RepVGGBlock.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 n: int = 1,
                 block: Type[nn.Module] = RepVGGBlock):
        super().__init__()
        self.conv1 = block(in_channels, out_channels)
        self.block = nn.Sequential(*(block(out_channels, out_channels)
                                     for _ in range(n - 1))) if n > 1 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process."""
        x = self.conv1(x)
        if self.block is not None:
            x = self.block(x)
        return x
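
A minimal usage sketch for the stage block, assuming the classes above are in scope; channel and spatial sizes are arbitrary:

import torch

stage = RepStageBlock(in_channels=64, out_channels=128, n=3)
x = torch.rand(1, 64, 32, 32)
# conv1 maps 64 -> 128; the remaining n - 1 blocks keep 128 channels.
assert stage(x).shape == (1, 128, 32, 32)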