mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Fix TensorRT --dynamic excess outputs bug (#8869)
* Fix TensorRT --dynamic excess outputs bug Potential fix for https://github.com/ultralytics/yolov5/issues/8790 * Cleanup * Update common.py * Update common.py * New fix
This commit is contained in:
parent
84e7748564
commit
38a6eb6e99
@ -387,13 +387,13 @@ class DetectMultiBackend(nn.Module):
|
||||
context = model.create_execution_context()
|
||||
bindings = OrderedDict()
|
||||
fp16 = False # default updated below
|
||||
dynamic_input = False
|
||||
dynamic = False
|
||||
for index in range(model.num_bindings):
|
||||
name = model.get_binding_name(index)
|
||||
dtype = trt.nptype(model.get_binding_dtype(index))
|
||||
if model.binding_is_input(index):
|
||||
if -1 in tuple(model.get_binding_shape(index)): # dynamic
|
||||
dynamic_input = True
|
||||
dynamic = True
|
||||
context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2]))
|
||||
if dtype == np.float16:
|
||||
fp16 = True
|
||||
@ -471,12 +471,14 @@ class DetectMultiBackend(nn.Module):
|
||||
im = im.cpu().numpy() # FP32
|
||||
y = self.executable_network([im])[self.output_layer]
|
||||
elif self.engine: # TensorRT
|
||||
if im.shape != self.bindings['images'].shape and self.dynamic_input:
|
||||
self.context.set_binding_shape(self.model.get_binding_index('images'), im.shape) # reshape if dynamic
|
||||
if self.dynamic and im.shape != self.bindings['images'].shape:
|
||||
i_in, i_out = (self.model.get_binding_index(x) for x in ('images', 'output'))
|
||||
self.context.set_binding_shape(i_in, im.shape) # reshape if dynamic
|
||||
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
||||
assert im.shape == self.bindings['images'].shape, (
|
||||
f"image shape {im.shape} exceeds model max shape {self.bindings['images'].shape}" if self.dynamic_input
|
||||
else f"image shape {im.shape} does not match model shape {self.bindings['images'].shape}")
|
||||
self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out)))
|
||||
s = self.bindings['images'].shape
|
||||
assert im.shape == s, f"image shape {im.shape} " + \
|
||||
f"exceeds model max shape {s}" if self.dynamic else f"does not match model shape {s}"
|
||||
self.binding_addrs['images'] = int(im.data_ptr())
|
||||
self.context.execute_v2(list(self.binding_addrs.values()))
|
||||
y = self.bindings['output'].data
|
||||
|
Loading…
x
Reference in New Issue
Block a user