diff --git a/ppcls/arch/backbone/base/dbb_block.py b/ppcls/arch/backbone/base/dbb/dbb_block.py
similarity index 76%
rename from ppcls/arch/backbone/base/dbb_block.py
rename to ppcls/arch/backbone/base/dbb/dbb_block.py
index 3b66530a4..f38c5c257 100644
--- a/ppcls/arch/backbone/base/dbb_block.py
+++ b/ppcls/arch/backbone/base/dbb/dbb_block.py
@@ -18,6 +18,7 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 import numpy as np
+from .dbb_transforms import *
 
 
 def conv_bn(in_channels,
@@ -62,7 +63,6 @@ class IdentityBasedConv1x1(nn.Conv2D):
         for i in range(channels):
             id_value[i, i % input_dim, 0, 0] = 1
         self.id_tensor = paddle.to_tensor(id_value)
-        # nn.init.zeros_(self.weight)
         self.weight.set_value(paddle.zeros_like(self.weight))
 
     def forward(self, input):
@@ -143,20 +143,21 @@ class DiverseBranchBlock(nn.Layer):
                  stride=1,
                  groups=1,
                  act=None,
+                 is_repped=False,
+                 single_init=False,
                  **kwargs):
         super().__init__()
 
         padding = (filter_size - 1) // 2
         dilation = 1
-        deploy = False
-        single_init = False
+
         in_channels = num_channels
         out_channels = num_filters
         kernel_size = filter_size
         internal_channels_1x1_3x3 = None
         nonlinear = act
 
-        self.deploy = deploy
+        self.is_repped = is_repped
 
         if nonlinear is None:
             self.nonlinear = nn.Identity()
@@ -168,7 +169,7 @@ class DiverseBranchBlock(nn.Layer):
         self.groups = groups
         assert padding == kernel_size // 2
 
-        if deploy:
+        if is_repped:
             self.dbb_reparam = nn.Conv2D(
                 in_channels=in_channels,
                 out_channels=out_channels,
@@ -268,8 +269,7 @@ class DiverseBranchBlock(nn.Layer):
             self.single_init()
 
     def forward(self, inputs):
-
-        if hasattr(self, 'dbb_reparam'):
+        if self.is_repped:
             return self.nonlinear(self.dbb_reparam(inputs))
 
         out = self.dbb_origin(inputs)
@@ -293,3 +293,73 @@ class DiverseBranchBlock(nn.Layer):
             self.init_gamma(0.0)
         if hasattr(self, "dbb_origin"):
             paddle.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
+
+    def get_equivalent_kernel_bias(self):
+        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
+                                           self.dbb_origin.bn)
+
+        if hasattr(self, 'dbb_1x1'):
+            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
+                                         self.dbb_1x1.bn)
+            k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
+        else:
+            k_1x1, b_1x1 = 0, 0
+
+        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
+            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
+        else:
+            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
+        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first,
+                                                         self.dbb_1x1_kxk.bn1)
+        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(
+            self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
+        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(
+            k_1x1_kxk_first,
+            b_1x1_kxk_first,
+            k_1x1_kxk_second,
+            b_1x1_kxk_second,
+            groups=self.groups)
+
+        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
+        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg,
+                                                           self.dbb_avg.avgbn)
+        if hasattr(self.dbb_avg, 'conv'):
+            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(
+                self.dbb_avg.conv.weight, self.dbb_avg.bn)
+            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(
+                k_1x1_avg_first,
+                b_1x1_avg_first,
+                k_1x1_avg_second,
+                b_1x1_avg_second,
+                groups=self.groups)
+        else:
+            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
+
+        return transII_addbranch(
+            (k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
+            (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))
+
+    def re_parameterize(self):
+        if self.is_repped:
+            return
+
+        kernel, bias = self.get_equivalent_kernel_bias()
+        self.dbb_reparam = nn.Conv2D(
+            in_channels=self.dbb_origin.conv._in_channels,
+            out_channels=self.dbb_origin.conv._out_channels,
+            kernel_size=self.dbb_origin.conv._kernel_size,
+            stride=self.dbb_origin.conv._stride,
+            padding=self.dbb_origin.conv._padding,
+            dilation=self.dbb_origin.conv._dilation,
+            groups=self.dbb_origin.conv._groups,
+            bias_attr=True)
+
+        self.dbb_reparam.weight.set_value(kernel)
+        self.dbb_reparam.bias.set_value(bias)
+
+        self.__delattr__('dbb_origin')
+        self.__delattr__('dbb_avg')
+        if hasattr(self, 'dbb_1x1'):
+            self.__delattr__('dbb_1x1')
+        self.__delattr__('dbb_1x1_kxk')
+        self.is_repped = True
diff --git a/ppcls/arch/backbone/base/dbb/dbb_transforms.py b/ppcls/arch/backbone/base/dbb/dbb_transforms.py
new file mode 100644
index 000000000..70f55fb09
--- /dev/null
+++ b/ppcls/arch/backbone/base/dbb/dbb_transforms.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# reference: https://arxiv.org/abs/2103.13425, https://github.com/DingXiaoH/DiverseBranchBlock
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+
+
+def transI_fusebn(kernel, bn):
+    gamma = bn.weight
+    std = (bn._variance + bn._epsilon).sqrt()
+    return kernel * (
+        (gamma / std).reshape([-1, 1, 1, 1])), bn.bias - bn._mean * gamma / std
+
+
+def transII_addbranch(kernels, biases):
+    return sum(kernels), sum(biases)
+
+
+def transIII_1x1_kxk(k1, b1, k2, b2, groups):
+    if groups == 1:
+        k = F.conv2d(k2, k1.transpose([1, 0, 2, 3]))
+        b_hat = (k2 * b1.reshape([1, -1, 1, 1])).sum((1, 2, 3))
+    else:
+        k_slices = []
+        b_slices = []
+        k1_T = k1.transpose([1, 0, 2, 3])
+        k1_group_width = k1.shape[0] // groups
+        k2_group_width = k2.shape[0] // groups
+        for g in range(groups):
+            k1_T_slice = k1_T[:, g * k1_group_width:(g + 1) *
+                              k1_group_width, :, :]
+            k2_slice = k2[g * k2_group_width:(g + 1) * k2_group_width, :, :, :]
+            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
+            b_slices.append((k2_slice * b1[g * k1_group_width:(
+                g + 1) * k1_group_width].reshape([1, -1, 1, 1])).sum((1, 2, 3
+                                                                     )))
+        k, b_hat = transIV_depthconcat(k_slices, b_slices)
+    return k, b_hat + b2
+
+
+def transIV_depthconcat(kernels, biases):
+    return paddle.concat(kernels, axis=0), paddle.concat(biases)
+
+
+def transV_avg(channels, kernel_size, groups):
+    input_dim = channels // groups
+    k = paddle.zeros((channels, input_dim, kernel_size, kernel_size))
+    k[np.arange(channels), np.tile(np.arange(input_dim),
+                                   groups), :, :] = 1.0 / kernel_size**2
+    return k
+
+
+# This has not been tested with non-square kernels (kernel.shape[2] != kernel.shape[3]) or with even-size kernels
+def transVI_multiscale(kernel, target_kernel_size):
+    H_pixels_to_pad = (target_kernel_size - kernel.shape[2]) // 2
+    W_pixels_to_pad = (target_kernel_size - kernel.shape[3]) // 2
+    return F.pad(
+        kernel,
+        [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py
index 413a728d8..235268cfe 100644
--- a/ppcls/arch/backbone/legendary_models/resnet.py
+++ b/ppcls/arch/backbone/legendary_models/resnet.py
@@ -28,7 +28,7 @@ import math
 
 from ....utils import logger
 from ..base.theseus_layer import TheseusLayer
-from ..base.dbb_block import DiverseBranchBlock
+from ..base.dbb.dbb_block import DiverseBranchBlock
 from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 
 MODEL_URLS = {
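
The patch makes the block switchable between its training-time multi-branch form and a single fused convolution. A minimal usage sketch (not part of the patch), assuming the constructor signature shown in the diff above; channel counts and input shape are arbitrary:

    import paddle
    from ppcls.arch.backbone.base.dbb.dbb_block import DiverseBranchBlock

    # Build a 3x3 block; eval mode is required so BN uses the running
    # statistics that transI_fusebn folds into the conv weights.
    block = DiverseBranchBlock(num_channels=16, num_filters=32, filter_size=3)
    block.eval()

    x = paddle.rand([1, 16, 56, 56])
    out_branches = block(x)

    block.re_parameterize()  # folds every branch into block.dbb_reparam
    out_fused = block(x)

    # The fused conv should reproduce the multi-branch output
    # up to floating-point error.
    print(float(paddle.abs(out_branches - out_fused).max()))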