# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://arxiv.org/abs/2103.13425, https://github.com/DingXiaoH/DiverseBranchBlock

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

from .dbb_transforms import *


def conv_bn(in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            padding_mode='zeros'):
    """Return a conv-BN pair wrapped in a Sequential with named sublayers."""
    conv_layer = nn.Conv2D(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias_attr=False,
        padding_mode=padding_mode)
    bn_layer = nn.BatchNorm2D(num_features=out_channels)
    se = nn.Sequential()
    se.add_sublayer('conv', conv_layer)
    se.add_sublayer('bn', bn_layer)
    return se


class IdentityBasedConv1x1(nn.Conv2D):
    """1x1 conv whose effective kernel is `weight + identity`, so it acts
    as an identity mapping at initialization (the weight starts at zero)."""

    def __init__(self, channels, groups=1):
        super(IdentityBasedConv1x1, self).__init__(
            in_channels=channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=groups,
            bias_attr=False)

        assert channels % groups == 0
        input_dim = channels // groups
        id_value = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        # Cast to the weight dtype; `paddle.to_tensor` would otherwise keep
        # the numpy array's float64 dtype and break `weight + id_tensor`.
        self.id_tensor = paddle.to_tensor(id_value, dtype=self.weight.dtype)
        self.weight.set_value(paddle.zeros_like(self.weight))

    def forward(self, input):
        kernel = self.weight + self.id_tensor
        result = F.conv2d(
            input,
            kernel,
            None,
            stride=1,
            padding=0,
            dilation=self._dilation,
            groups=self._groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor


class BNAndPad(nn.Layer):
    """BatchNorm followed by spatial padding with the per-channel value BN
    would assign to the zero padding of the preceding conv, which keeps the
    branch equivalent to a single padded conv after fusion."""

    def __init__(self,
                 pad_pixels,
                 num_features,
                 epsilon=1e-5,
                 momentum=0.9,
                 last_conv_bias=None,
                 bn=nn.BatchNorm2D):
        super().__init__()
        # Note: Paddle's BN momentum is the decay of the moving average
        # (`running = momentum * running + (1 - momentum) * batch`), the
        # complement of the torch convention, so 0.9 here corresponds to
        # torch's 0.1.
        self.bn = bn(num_features, momentum=momentum, epsilon=epsilon)
        self.pad_pixels = pad_pixels
        self.last_conv_bias = last_conv_bias

    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            bias = -self.bn._mean
            if self.last_conv_bias is not None:
                bias += self.last_conv_bias
            pad_values = self.bn.bias + self.bn.weight * (
                bias / paddle.sqrt(self.bn._variance + self.bn._epsilon))
            # Pad with the per-channel constant computed above.
            # TODO: n,h,w,c format is not supported yet
            n, c, h, w = output.shape
            values = pad_values.reshape([1, -1, 1, 1])
            w_values = values.expand([n, -1, self.pad_pixels, w])
            x = paddle.concat([w_values, output, w_values], axis=2)
            h = h + self.pad_pixels * 2
            h_values = values.expand([n, -1, h, self.pad_pixels])
            x = paddle.concat([h_values, x, h_values], axis=3)
            output = x
        return output

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def _mean(self):
        return self.bn._mean

    @property
    def _variance(self):
        return self.bn._variance

    @property
    def _epsilon(self):
        return self.bn._epsilon


class DiverseBranchBlock(nn.Layer):
    """Diverse Branch Block: a conv-BN block with 1x1, 1x1-kxk, and 1x1-avg
    side branches that can be folded into a single kxk conv for inference
    via `re_parameterize()`."""

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None,
                 is_repped=False,
                 single_init=False,
                 **kwargs):
        super().__init__()

        padding = (filter_size - 1) // 2
        dilation = 1

        in_channels = num_channels
        out_channels = num_filters
        kernel_size = filter_size
        internal_channels_1x1_3x3 = None
        nonlinear = act

        self.is_repped = is_repped

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nn.ReLU()

        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        assert padding == kernel_size // 2

        if is_repped:
            self.dbb_reparam = nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias_attr=True)
        else:
            self.dbb_origin = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups)

            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_sublayer(
                    'conv',
                    nn.Conv2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=groups,
                        bias_attr=False))
                self.dbb_avg.add_sublayer(
                    'bn',
                    BNAndPad(
                        pad_pixels=padding, num_features=out_channels))
                self.dbb_avg.add_sublayer(
                    'avg',
                    nn.AvgPool2D(
                        kernel_size=kernel_size, stride=stride, padding=0))
                self.dbb_1x1 = conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                    groups=groups)
            else:
                self.dbb_avg.add_sublayer(
                    'avg',
                    nn.AvgPool2D(
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding))

            self.dbb_avg.add_sublayer('avgbn', nn.BatchNorm2D(out_channels))

            if internal_channels_1x1_3x3 is None:
                # For mobilenet, it is better to have 2X internal channels
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_sublayer(
                    'idconv1',
                    IdentityBasedConv1x1(
                        channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_sublayer(
                    'conv1',
                    nn.Conv2D(
                        in_channels=in_channels,
                        out_channels=internal_channels_1x1_3x3,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=groups,
                        bias_attr=False))
            self.dbb_1x1_kxk.add_sublayer(
                'bn1',
                BNAndPad(
                    pad_pixels=padding,
                    num_features=internal_channels_1x1_3x3))
            self.dbb_1x1_kxk.add_sublayer(
                'conv2',
                nn.Conv2D(
                    in_channels=internal_channels_1x1_3x3,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=0,
                    groups=groups,
                    bias_attr=False))
            self.dbb_1x1_kxk.add_sublayer('bn2', nn.BatchNorm2D(out_channels))

        # The experiments reported in the paper used the default
        # initialization of bn.weight (all as 1). But changing the
        # initialization may be useful in some cases.
        if single_init:
            # Initialize the bn.weight of dbb_origin as 1 and others as 0.
            # This is not the default setting.
            self.single_init()

    def forward(self, inputs):
        if self.is_repped:
            return self.nonlinear(self.dbb_reparam(inputs))

        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

    def init_gamma(self, gamma_value):
        # Paddle has no `nn.init.constant_` (that is a torch API); assign
        # the constant through `set_value` instead.
        def _set_constant(param, value):
            param.set_value(paddle.full_like(param, value))

        if hasattr(self, "dbb_origin"):
            _set_constant(self.dbb_origin.bn.weight, gamma_value)
        if hasattr(self, "dbb_1x1"):
            _set_constant(self.dbb_1x1.bn.weight, gamma_value)
        if hasattr(self, "dbb_avg"):
            _set_constant(self.dbb_avg.avgbn.weight, gamma_value)
        if hasattr(self, "dbb_1x1_kxk"):
            _set_constant(self.dbb_1x1_kxk.bn2.weight, gamma_value)

    def single_init(self):
        self.init_gamma(0.0)
        if hasattr(self, "dbb_origin"):
            self.dbb_origin.bn.weight.set_value(
                paddle.full_like(self.dbb_origin.bn.weight, 1.0))

    def get_equivalent_kernel_bias(self):
        # Fuse each branch into an equivalent kxk kernel and bias with the
        # transforms from dbb_transforms, then sum the branches.
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)

        if hasattr(self, 'dbb_1x1'):
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0

        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(
            k_1x1_kxk_first, self.dbb_1x1_kxk.bn1)
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(
            self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(
            k_1x1_kxk_first,
            b_1x1_kxk_first,
            k_1x1_kxk_second,
            b_1x1_kxk_second,
            groups=self.groups)

        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(
            k_avg, self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(
                self.dbb_avg.conv.weight, self.dbb_avg.bn)
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(
                k_1x1_avg_first,
                b_1x1_avg_first,
                k_1x1_avg_second,
                b_1x1_avg_second,
                groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second

        return transII_addbranch(
            (k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
            (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))

    def re_parameterize(self):
        if self.is_repped:
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2D(
            in_channels=self.dbb_origin.conv._in_channels,
            out_channels=self.dbb_origin.conv._out_channels,
            kernel_size=self.dbb_origin.conv._kernel_size,
            stride=self.dbb_origin.conv._stride,
            padding=self.dbb_origin.conv._padding,
            dilation=self.dbb_origin.conv._dilation,
            groups=self.dbb_origin.conv._groups,
            bias_attr=True)

        self.dbb_reparam.weight.set_value(kernel)
        self.dbb_reparam.bias.set_value(bias)

        # Drop the training-time branches; only the fused conv remains.
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
        self.is_repped = True
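
# A minimal usage sketch, not part of the upstream file: it assumes the
# module is run with its package context (e.g. `python -m <package>.<module>`)
# so the relative import of dbb_transforms resolves. It builds a block, folds
# the branches, and checks that the fused conv reproduces the multi-branch
# output; the shapes and channel counts below are arbitrary illustration.
if __name__ == "__main__":
    block = DiverseBranchBlock(
        num_channels=16, num_filters=32, filter_size=3, act="relu")
    block.eval()  # fusion assumes BN uses its accumulated statistics
    x = paddle.randn([2, 16, 32, 32])
    y_branches = block(x)

    block.re_parameterize()  # fold all branches into one 3x3 conv
    y_fused = block(x)

    # The two outputs should agree up to floating-point error.
    print(float(paddle.abs(y_branches - y_fused).max()))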