model fuse
parent
12b0c046d5
commit
c672bef10f
|
@ -21,6 +21,8 @@ def detect(save_img=False):
|
|||
google_utils.attempt_download(weights)
|
||||
model = torch.load(weights, map_location=device)['model']
|
||||
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
|
||||
# model.fuse()
|
||||
model.to(device).eval()
|
||||
|
||||
# Second-stage classifier
|
||||
classify = False
|
||||
|
@ -29,12 +31,6 @@ def detect(save_img=False):
|
|||
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
|
||||
modelc.to(device).eval()
|
||||
|
||||
# Eval mode
|
||||
model.to(device).eval()
|
||||
|
||||
# Fuse Conv2d + BatchNorm2d layers
|
||||
# model.fuse()
|
||||
|
||||
# Half precision
|
||||
half = half and device.type != 'cpu' # half precision only supported on CUDA
|
||||
if half:
|
||||
|
|
Loading…
Reference in New Issue