From 89655a84f2b9ef952e957e90fe622de167f37859 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 20 Sep 2020 11:57:19 -0700 Subject: [PATCH] .fuse() gradient introduction bug fix --- utils/torch_utils.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index bc54ab1e2..e3b23f9b8 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -104,28 +104,28 @@ def prune(model, amount=0.3): def fuse_conv_and_bn(conv, bn): - # https://tehnokv.com/posts/fusing-batchnorm-and-conv/ - with torch.no_grad(): - # init - fusedconv = nn.Conv2d(conv.in_channels, - conv.out_channels, - kernel_size=conv.kernel_size, - stride=conv.stride, - padding=conv.padding, - groups=conv.groups, - bias=True).to(conv.weight.device) + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ - # prepare filters - w_conv = conv.weight.clone().view(conv.out_channels, -1) - w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) - fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + # init + fusedconv = nn.Conv2d(conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True).requires_grad_(False).to(conv.weight.device) - # prepare spatial bias - b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias - b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) - fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) - return fusedconv + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv def model_info(model, verbose=False):