diff --git a/models/yolo.py b/models/yolo.py index 46f1375e5..a9dc539bf 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -162,7 +162,7 @@ class Model(nn.Module): def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers print('Fusing layers... ') for m in self.model.modules(): - if type(m) is Conv and hasattr(Conv, 'bn'): + if type(m) is Conv and hasattr(m, 'bn'): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, 'bn') # remove batchnorm