mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
torch.jit.trace()
fix (#9363)
* Update common.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update ci-testing.yml Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
cafdd18939
commit
23d0456b08
3
.github/workflows/ci-testing.yml
vendored
3
.github/workflows/ci-testing.yml
vendored
@ -119,9 +119,12 @@ jobs:
|
|||||||
python export.py --weights $m.pt --img 64 --include torchscript # export
|
python export.py --weights $m.pt --img 64 --include torchscript # export
|
||||||
python - <<EOF
|
python - <<EOF
|
||||||
import torch
|
import torch
|
||||||
|
im = torch.zeros([1, 3, 64, 64])
|
||||||
for path in '$m', '$b':
|
for path in '$m', '$b':
|
||||||
model = torch.hub.load('.', 'custom', path=path, source='local')
|
model = torch.hub.load('.', 'custom', path=path, source='local')
|
||||||
print(model('data/images/bus.jpg'))
|
print(model('data/images/bus.jpg'))
|
||||||
|
model(im) # warmup, build grids for trace
|
||||||
|
torch.jit.trace(model, [im])
|
||||||
EOF
|
EOF
|
||||||
- name: Test classification
|
- name: Test classification
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
|
@ -600,6 +600,7 @@ class AutoShape(nn.Module):
|
|||||||
if self.pt:
|
if self.pt:
|
||||||
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
||||||
m.inplace = False # Detect.inplace=False for safe multithread inference
|
m.inplace = False # Detect.inplace=False for safe multithread inference
|
||||||
|
m.export = True # do not output loss values
|
||||||
|
|
||||||
def _apply(self, fn):
|
def _apply(self, fn):
|
||||||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
||||||
|
Loading…
x
Reference in New Issue
Block a user