.fuse() additional error checking

pull/1007/head
Glenn Jocher 2020-09-20 12:04:20 -07:00
parent 89655a84f2
commit 2f77cf33f6
1 changed files with 1 additions and 1 deletions

View File

@ -160,7 +160,7 @@ class Model(nn.Module):
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
print('Fusing layers... ')
for m in self.model.modules():
if type(m) is Conv:
if type(m) is Conv and hasattr(Conv, 'bn'):
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm