`torch.jit.trace()` fix (#9363)
* Update common.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update ci-testing.yml Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9367/head
parent
cafdd18939
commit
23d0456b08
|
@ -119,9 +119,12 @@ jobs:
|
|||
python export.py --weights $m.pt --img 64 --include torchscript # export
|
||||
python - <<EOF
|
||||
import torch
|
||||
im = torch.zeros([1, 3, 64, 64])
|
||||
for path in '$m', '$b':
|
||||
model = torch.hub.load('.', 'custom', path=path, source='local')
|
||||
print(model('data/images/bus.jpg'))
|
||||
model(im) # warmup, build grids for trace
|
||||
torch.jit.trace(model, [im])
|
||||
EOF
|
||||
- name: Test classification
|
||||
shell: bash # for Windows compatibility
|
||||
|
|
|
@ -600,6 +600,7 @@ class AutoShape(nn.Module):
|
|||
if self.pt:
|
||||
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
||||
m.inplace = False # Detect.inplace=False for safe multithread inference
|
||||
m.export = True # do not output loss values
|
||||
|
||||
def _apply(self, fn):
|
||||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
||||
|
|
Loading…
Reference in New Issue