mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
.fuse() gradient introduction bug fix
This commit is contained in:
parent
c4cb78570c
commit
89655a84f2
@ -104,8 +104,8 @@ def prune(model, amount=0.3):
|
|||||||
|
|
||||||
|
|
||||||
def fuse_conv_and_bn(conv, bn):
|
def fuse_conv_and_bn(conv, bn):
|
||||||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||||
with torch.no_grad():
|
|
||||||
# init
|
# init
|
||||||
fusedconv = nn.Conv2d(conv.in_channels,
|
fusedconv = nn.Conv2d(conv.in_channels,
|
||||||
conv.out_channels,
|
conv.out_channels,
|
||||||
@ -113,7 +113,7 @@ def fuse_conv_and_bn(conv, bn):
|
|||||||
stride=conv.stride,
|
stride=conv.stride,
|
||||||
padding=conv.padding,
|
padding=conv.padding,
|
||||||
groups=conv.groups,
|
groups=conv.groups,
|
||||||
bias=True).to(conv.weight.device)
|
bias=True).requires_grad_(False).to(conv.weight.device)
|
||||||
|
|
||||||
# prepare filters
|
# prepare filters
|
||||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user