.fuse() additional error checking
parent
89655a84f2
commit
2f77cf33f6
|
@ -160,7 +160,7 @@ class Model(nn.Module):
|
||||||
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
||||||
print('Fusing layers... ')
|
print('Fusing layers... ')
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if type(m) is Conv:
|
if type(m) is Conv and hasattr(Conv, 'bn'):
|
||||||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
||||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||||
delattr(m, 'bn') # remove batchnorm
|
delattr(m, 'bn') # remove batchnorm
|
||||||
|
|
Loading…
Reference in New Issue