mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
PyTorch 1.7.0 Compatibility Updates (#1233)
* torch 1.7.0 compatibility updates * add inference verification
This commit is contained in:
parent
453acdec67
commit
c8c5ef36c9
@ -108,3 +108,11 @@ def yolov5x(pretrained=False, channels=3, classes=80):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
||||||
|
model = model.fuse().eval().autoshape() # for autoshaping of PIL/cv2/np inputs and NMS
|
||||||
|
|
||||||
|
# Verify inference
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
img = Image.open('inference/images/zidane.jpg')
|
||||||
|
y = model(img)
|
||||||
|
print(y[0].shape)
|
||||||
|
@ -136,6 +136,13 @@ def attempt_load(weights, map_location=None):
|
|||||||
attempt_download(w)
|
attempt_download(w)
|
||||||
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
|
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
|
||||||
|
|
||||||
|
# Compatibility updates
|
||||||
|
for m in model.modules():
|
||||||
|
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
||||||
|
m.inplace = True # pytorch 1.7.0 compatibility
|
||||||
|
elif type(m) is Conv:
|
||||||
|
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
||||||
|
|
||||||
if len(model) == 1:
|
if len(model) == 1:
|
||||||
return model[-1] # return model
|
return model[-1] # return model
|
||||||
else:
|
else:
|
||||||
|
@ -165,7 +165,6 @@ class Model(nn.Module):
|
|||||||
print('Fusing layers... ')
|
print('Fusing layers... ')
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if type(m) is Conv and hasattr(m, 'bn'):
|
if type(m) is Conv and hasattr(m, 'bn'):
|
||||||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
|
||||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||||
delattr(m, 'bn') # remove batchnorm
|
delattr(m, 'bn') # remove batchnorm
|
||||||
m.forward = m.fuseforward # update forward
|
m.forward = m.fuseforward # update forward
|
||||||
|
@ -74,7 +74,7 @@ def initialize_weights(model):
|
|||||||
elif t is nn.BatchNorm2d:
|
elif t is nn.BatchNorm2d:
|
||||||
m.eps = 1e-3
|
m.eps = 1e-3
|
||||||
m.momentum = 0.03
|
m.momentum = 0.03
|
||||||
elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
||||||
m.inplace = True
|
m.inplace = True
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user