.fuse() gradient introduction bug fix

pull/1007/head
Glenn Jocher 2020-09-20 11:57:19 -07:00
parent c4cb78570c
commit 89655a84f2
1 changed file with 19 additions and 19 deletions

@@ -104,28 +104,28 @@ def prune(model, amount=0.3):
 def fuse_conv_and_bn(conv, bn):
-    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
-    with torch.no_grad():
-        # init
-        fusedconv = nn.Conv2d(conv.in_channels,
-                              conv.out_channels,
-                              kernel_size=conv.kernel_size,
-                              stride=conv.stride,
-                              padding=conv.padding,
-                              groups=conv.groups,
-                              bias=True).to(conv.weight.device)
+    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+
+    # init
+    fusedconv = nn.Conv2d(conv.in_channels,
+                          conv.out_channels,
+                          kernel_size=conv.kernel_size,
+                          stride=conv.stride,
+                          padding=conv.padding,
+                          groups=conv.groups,
+                          bias=True).requires_grad_(False).to(conv.weight.device)
 
-        # prepare filters
-        w_conv = conv.weight.clone().view(conv.out_channels, -1)
-        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
-        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+    # prepare filters
+    w_conv = conv.weight.clone().view(conv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
 
-        # prepare spatial bias
-        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
-        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
-        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+    # prepare spatial bias
+    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
 
-        return fusedconv
+    return fusedconv
 
 
 def model_info(model, verbose=False):
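
The bug being fixed: the old code built the fused layer inside a torch.no_grad() block, which only disables gradient tracking while the block runs; the freshly created Conv2d parameters still default to requires_grad=True, so a fused model silently starts accumulating gradients again. Calling .requires_grad_(False) on the fused module freezes its parameters outright. Below is a minimal sketch of the difference, followed by a check that the fusion steps from the patch are numerically equivalent to running Conv2d then BatchNorm2d in eval mode; layer sizes here are illustrative, not from the repo.

import torch
import torch.nn as nn

# Old approach: no_grad() does not freeze newly created parameters.
with torch.no_grad():
    fused_old = nn.Conv2d(3, 8, 3, bias=True)
# New approach: explicitly mark every parameter of the fused layer as frozen.
fused_new = nn.Conv2d(3, 8, 3, bias=True).requires_grad_(False)

print(fused_old.weight.requires_grad)  # True  -> gradients reappear after .fuse()
print(fused_new.weight.requires_grad)  # False -> fused layer stays frozen

# Equivalence check for the fusion math, following the patched steps:
conv = nn.Conv2d(3, 8, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()  # eval mode so running stats are used

fusedconv = nn.Conv2d(conv.in_channels, conv.out_channels,
                      kernel_size=conv.kernel_size, stride=conv.stride,
                      padding=conv.padding, groups=conv.groups,
                      bias=True).requires_grad_(False)

# W_fused = diag(gamma / sqrt(running_var + eps)) @ W_conv
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

# b_fused = W_bn @ b_conv + (beta - gamma * running_mean / sqrt(running_var + eps))
b_conv = torch.zeros(conv.weight.size(0)) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

x = torch.randn(1, 3, 32, 32)
assert torch.allclose(bn(conv(x)), fusedconv(x), atol=1e-5)

The copy_() calls work without a no_grad context precisely because the fused parameters no longer require grad; an in-place copy into a leaf tensor that requires grad would raise a runtime error.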