model fuse
parent
12b0c046d5
commit
c672bef10f
|
@ -21,6 +21,8 @@ def detect(save_img=False):
|
||||||
google_utils.attempt_download(weights)
|
google_utils.attempt_download(weights)
|
||||||
model = torch.load(weights, map_location=device)['model']
|
model = torch.load(weights, map_location=device)['model']
|
||||||
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
|
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
|
||||||
|
# model.fuse()
|
||||||
|
model.to(device).eval()
|
||||||
|
|
||||||
# Second-stage classifier
|
# Second-stage classifier
|
||||||
classify = False
|
classify = False
|
||||||
|
@ -29,12 +31,6 @@ def detect(save_img=False):
|
||||||
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
|
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
|
||||||
modelc.to(device).eval()
|
modelc.to(device).eval()
|
||||||
|
|
||||||
# Eval mode
|
|
||||||
model.to(device).eval()
|
|
||||||
|
|
||||||
# Fuse Conv2d + BatchNorm2d layers
|
|
||||||
# model.fuse()
|
|
||||||
|
|
||||||
# Half precision
|
# Half precision
|
||||||
half = half and device.type != 'cpu' # half precision only supported on CUDA
|
half = half and device.type != 'cpu' # half precision only supported on CUDA
|
||||||
if half:
|
if half:
|
||||||
|
|
Loading…
Reference in New Issue