support re_parameterize

parent f82871b1f8
commit d4d3d01384

@@ -18,6 +18,7 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
+from .dbb_transforms import *


def conv_bn(in_channels,
@@ -62,7 +63,6 @@ class IdentityBasedConv1x1(nn.Conv2D):
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        self.id_tensor = paddle.to_tensor(id_value)
-        # nn.init.zeros_(self.weight)
        self.weight.set_value(paddle.zeros_like(self.weight))

    def forward(self, input):
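
For context, `id_value` above is an identity kernel: convolving any input with it returns the input unchanged, so a freshly constructed IdentityBasedConv1x1 (zero weight plus `id_tensor`) starts out as an identity mapping. A minimal standalone check (illustrative sketch, shapes chosen arbitrarily):

import numpy as np
import paddle
import paddle.nn.functional as F

channels, groups = 4, 1
input_dim = channels // groups
id_value = np.zeros((channels, input_dim, 1, 1), dtype="float32")
for i in range(channels):
    id_value[i, i % input_dim, 0, 0] = 1
id_kernel = paddle.to_tensor(id_value)

x = paddle.rand([2, channels, 8, 8])
print(float((F.conv2d(x, id_kernel) - x).abs().max()))  # 0.0
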
@@ -143,20 +143,21 @@ class DiverseBranchBlock(nn.Layer):
                 stride=1,
                 groups=1,
                 act=None,
+                 is_repped=False,
                 single_init=False,
                 **kwargs):
        super().__init__()

        padding = (filter_size - 1) // 2
        dilation = 1
+        deploy = False
+        single_init = False

        in_channels = num_channels
        out_channels = num_filters
        kernel_size = filter_size
        internal_channels_1x1_3x3 = None
        nonlinear = act

        self.deploy = deploy
+        self.is_repped = is_repped

        if nonlinear is None:
            self.nonlinear = nn.Identity()
@@ -168,7 +169,7 @@ class DiverseBranchBlock(nn.Layer):
        self.groups = groups
        assert padding == kernel_size // 2

-        if deploy:
+        if is_repped:
            self.dbb_reparam = nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
@@ -268,8 +269,7 @@ class DiverseBranchBlock(nn.Layer):
            self.single_init()

    def forward(self, inputs):
-        if hasattr(self, 'dbb_reparam'):
+        if self.is_repped:
            return self.nonlinear(self.dbb_reparam(inputs))

        out = self.dbb_origin(inputs)
@@ -293,3 +293,73 @@ class DiverseBranchBlock(nn.Layer):
        self.init_gamma(0.0)
        if hasattr(self, "dbb_origin"):
            # paddle has no nn.init; set the BN gamma to 1.0 directly
            self.dbb_origin.bn.weight.set_value(
                paddle.ones_like(self.dbb_origin.bn.weight))

    def get_equivalent_kernel_bias(self):
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)

        if hasattr(self, 'dbb_1x1'):
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0

        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first,
                                                         self.dbb_1x1_kxk.bn1)
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(
            self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(
            k_1x1_kxk_first,
            b_1x1_kxk_first,
            k_1x1_kxk_second,
            b_1x1_kxk_second,
            groups=self.groups)

        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg,
                                                           self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(
                self.dbb_avg.conv.weight, self.dbb_avg.bn)
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(
                k_1x1_avg_first,
                b_1x1_avg_first,
                k_1x1_avg_second,
                b_1x1_avg_second,
                groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second

        return transII_addbranch(
            (k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
            (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))

    def re_parameterize(self):
        if self.is_repped:
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2D(
            in_channels=self.dbb_origin.conv._in_channels,
            out_channels=self.dbb_origin.conv._out_channels,
            kernel_size=self.dbb_origin.conv._kernel_size,
            stride=self.dbb_origin.conv._stride,
            padding=self.dbb_origin.conv._padding,
            dilation=self.dbb_origin.conv._dilation,
            groups=self.dbb_origin.conv._groups,
            bias_attr=True)

        self.dbb_reparam.weight.set_value(kernel)
        self.dbb_reparam.bias.set_value(bias)

        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
        self.is_repped = True
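
These two methods are the heart of the commit: get_equivalent_kernel_bias folds every branch into a single kernel/bias via the trans* helpers, and re_parameterize swaps the branches for one plain Conv2D. A minimal before/after sketch (not from the commit; the constructor arguments assume the signature shown above, and eval() is needed so the BN fusion uses running statistics):

import paddle

block = DiverseBranchBlock(num_channels=16, num_filters=32, filter_size=3)
block.eval()  # fusion is exact only with BN in inference mode

x = paddle.rand([1, 16, 32, 32])
y_before = block(x)

block.re_parameterize()  # folds all branches into block.dbb_reparam
y_after = block(x)       # now a single Conv2D (+ activation)

print(float((y_before - y_after).abs().max()))  # ~1e-5
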
@@ -0,0 +1,73 @@
# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://arxiv.org/abs/2103.13425, https://github.com/DingXiaoH/DiverseBranchBlock

import numpy as np
import paddle
import paddle.nn.functional as F


def transI_fusebn(kernel, bn):
    gamma = bn.weight
    std = (bn._variance + bn._epsilon).sqrt()
    return kernel * (
        (gamma / std).reshape([-1, 1, 1, 1])), bn.bias - bn._mean * gamma / std
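
# Sanity-check sketch for transI_fusebn (illustrative, assuming the imports
# above; not part of the original transforms): in eval mode, BN(conv(x))
# should match a single conv built from the fused kernel and bias.
#
#     conv = paddle.nn.Conv2D(3, 8, 3, padding=1, bias_attr=False)
#     bn = paddle.nn.BatchNorm2D(8)
#     bn.eval()
#     k, b = transI_fusebn(conv.weight, bn)
#     x = paddle.rand([2, 3, 16, 16])
#     y_ref = bn(conv(x))
#     y_fused = F.conv2d(x, k, bias=b, padding=1)
#     assert float((y_ref - y_fused).abs().max()) < 1e-5
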
def transII_addbranch(kernels, biases):
    return sum(kernels), sum(biases)
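
# Illustrative sketch (not in the original file): parallel conv branches with
# identical geometry can be summed into one conv.
#
#     k1, k2 = paddle.rand([8, 3, 3, 3]), paddle.rand([8, 3, 3, 3])
#     b1, b2 = paddle.rand([8]), paddle.rand([8])
#     x = paddle.rand([2, 3, 16, 16])
#     k, b = transII_addbranch((k1, k2), (b1, b2))
#     y_ref = F.conv2d(x, k1, b1) + F.conv2d(x, k2, b2)
#     assert float((y_ref - F.conv2d(x, k, b)).abs().max()) < 1e-5
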
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    if groups == 1:
        k = F.conv2d(k2, k1.transpose([1, 0, 2, 3]))
        b_hat = (k2 * b1.reshape([1, -1, 1, 1])).sum((1, 2, 3))
    else:
        k_slices = []
        b_slices = []
        k1_T = k1.transpose([1, 0, 2, 3])
        k1_group_width = k1.shape[0] // groups
        k2_group_width = k2.shape[0] // groups
        for g in range(groups):
            k1_T_slice = k1_T[:, g * k1_group_width:(g + 1) * k1_group_width, :, :]
            k2_slice = k2[g * k2_group_width:(g + 1) * k2_group_width, :, :, :]
            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
            b_slices.append(
                (k2_slice * b1[g * k1_group_width:(g + 1) * k1_group_width]
                 .reshape([1, -1, 1, 1])).sum((1, 2, 3)))
        k, b_hat = transIV_depthconcat(k_slices, b_slices)
    return k, b_hat + b2
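
# Illustrative sketch (not in the original file): a 1x1 conv followed by a kxk
# conv merges into a single kxk conv; exact when the 1x1 stage uses no padding.
#
#     k1, b1 = paddle.rand([4, 3, 1, 1]), paddle.rand([4])
#     k2, b2 = paddle.rand([8, 4, 3, 3]), paddle.rand([8])
#     x = paddle.rand([2, 3, 16, 16])
#     k, b = transIII_1x1_kxk(k1, b1, k2, b2, groups=1)
#     y_ref = F.conv2d(F.conv2d(x, k1, b1), k2, b2)
#     assert float((y_ref - F.conv2d(x, k, b)).abs().max()) < 1e-4
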
def transIV_depthconcat(kernels, biases):
    # paddle has no `cat`; concatenation is `paddle.concat`
    return paddle.concat(kernels, axis=0), paddle.concat(biases)
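
# Illustrative sketch (not in the original file): stacking kernels along the
# output-channel axis equals concatenating the branch outputs channel-wise.
#
#     ka, ba = paddle.rand([4, 3, 3, 3]), paddle.rand([4])
#     kb, bb = paddle.rand([6, 3, 3, 3]), paddle.rand([6])
#     x = paddle.rand([2, 3, 16, 16])
#     k, b = transIV_depthconcat([ka, kb], [ba, bb])
#     y_ref = paddle.concat([F.conv2d(x, ka, ba), F.conv2d(x, kb, bb)], axis=1)
#     assert float((y_ref - F.conv2d(x, k, b)).abs().max()) < 1e-5
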
def transV_avg(channels, kernel_size, groups):
    input_dim = channels // groups
    k = paddle.zeros((channels, input_dim, kernel_size, kernel_size))
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = \
        1.0 / kernel_size**2
    return k
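
# Illustrative sketch (not in the original file): the kernel built by
# transV_avg reproduces average pooling as a convolution.
#
#     k = transV_avg(channels=4, kernel_size=3, groups=1)
#     x = paddle.rand([2, 4, 8, 8])
#     y_ref = F.avg_pool2d(x, kernel_size=3, stride=1)
#     assert float((y_ref - F.conv2d(x, k)).abs().max()) < 1e-6
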
# This has not been tested with non-square kernels
# (kernel.shape[2] != kernel.shape[3]) or with even-size kernels.
def transVI_multiscale(kernel, target_kernel_size):
    H_pixels_to_pad = (target_kernel_size - kernel.shape[2]) // 2
    W_pixels_to_pad = (target_kernel_size - kernel.shape[3]) // 2
    return F.pad(
        kernel,
        [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
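A minimal sketch of what transVI_multiscale guarantees (illustrative, with the helpers above in scope; shapes are assumptions): zero-padding a 1x1 kernel out to 3x3 and convolving with padding 1 reproduces the plain 1x1 convolution.

import paddle
import paddle.nn.functional as F

k1x1 = paddle.rand([8, 3, 1, 1])
k3x3 = transVI_multiscale(k1x1, 3)  # zero-pad the 1x1 kernel to 3x3
x = paddle.rand([2, 3, 16, 16])
y_1x1 = F.conv2d(x, k1x1)
y_3x3 = F.conv2d(x, k3x3, padding=1)
print(float((y_1x1 - y_3x3).abs().max()))  # ~0
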
@@ -28,7 +28,7 @@ import math

from ....utils import logger
from ..base.theseus_layer import TheseusLayer
-from ..base.dbb_block import DiverseBranchBlock
+from ..base.dbb.dbb_block import DiverseBranchBlock
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {