Mirror of https://github.com/ultralytics/yolov5.git
PyTorch Hub models default to CUDA:0 if available (#2472)
* PyTorch Hub models default to CUDA:0 if available
* device as string bug fix
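For context, a minimal usage sketch of the new behavior (not part of the commit; 'yolov5s' is just an example variant loaded through the standard torch.hub entrypoint):

import torch

# With this change, a Hub-loaded model lands on cuda:0 when a GPU is visible,
# otherwise it stays on the CPU.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
print(next(model.parameters()).device)  # cuda:0 if available, else cpu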
parent 2d41e70e82
commit 9b11f0c58b
hubconf.py

@@ -12,6 +12,7 @@ import torch
 from models.yolo import Model
 from utils.general import set_logging
 from utils.google_utils import attempt_download
+from utils.torch_utils import select_device

 dependencies = ['torch', 'yaml']
 set_logging()
@@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape):
             model.names = ckpt['model'].names  # set class names attribute
         if autoshape:
             model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
-        return model
+        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
+        return model.to(device)

     except Exception as e:
         help_url = 'https://github.com/ultralytics/yolov5/issues/36'
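A brief follow-up note (not from the diff): because create() now returns the model already placed on cuda:0 when a GPU is present, callers that want a different placement can still move it explicitly:

# assuming `model` was returned by create() / torch.hub.load above
model = model.cpu()           # force CPU inference
# model = model.to('cuda:1')  # or target a specific GPU on multi-GPU machines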
utils/datasets.py

@@ -385,7 +385,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         # Display cache
         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
         if exists:
-            d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
+            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
             tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
         assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
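As a side note on the tqdm call shown above: passing initial=n together with total=n renders an already-complete bar, which is how cached scan results are displayed without re-reading the dataset. A small sketch with made-up counts and cache name:

from tqdm import tqdm

nf, nm, ne, nc, n = 128, 0, 2, 0, 128  # hypothetical found / missing / empty / corrupted / total
d = f"Scanning 'coco128.cache' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=d, total=n, initial=n)  # prints a full bar immediately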
@@ -485,7 +485,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 nc += 1
                 print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

-            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
+            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                         f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

         if nf == 0:
utils/general.py

@@ -79,7 +79,7 @@ def check_git_status():
                 f"Use 'git pull' to update or 'git clone {url}' to download latest."
         else:
             s = f'up to date with {url} ✅'
-        print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
+        print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
     except Exception as e:
         print(e)
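The encode/decode guard added here (and reused in torch_utils.py below) strips emoji on Windows consoles that cannot encode them. A small illustrative helper wrapping the same pattern (the name emoji_safe is hypothetical, not part of the repo):

import platform

def emoji_safe(s=''):
    # Drop non-ASCII characters (e.g. ✅, 🚀) on Windows, where legacy consoles
    # may raise UnicodeEncodeError; return the string unchanged elsewhere.
    return s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s

print(emoji_safe('up to date with https://github.com/ultralytics/yolov5 ✅'))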
utils/torch_utils.py

@@ -1,8 +1,9 @@
 # PyTorch utils

 import logging
 import math
 import os
+import platform
 import subprocess
 import time
 from contextlib import contextmanager
@@ -53,7 +53,7 @@ def git_describe():

 def select_device(device='', batch_size=None):
     # device = 'cpu' or '0' or '0,1,2,3'
-    s = f'YOLOv5 {git_describe()} torch {torch.__version__} '  # string
+    s = f'YOLOv5 🚀 {git_describe()} torch {torch.__version__} '  # string
     cpu = device.lower() == 'cpu'
     if cpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
@@ -73,7 +73,7 @@ def select_device(device='', batch_size=None):
     else:
         s += 'CPU\n'

-    logger.info(s)  # skip a line
+    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
     return torch.device('cuda:0' if cuda else 'cpu')
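Tying back to the "device as string" note in the commit message: select_device takes its argument as a string such as 'cpu', '0' or '0,1,2,3'. A minimal sketch, assuming the yolov5 repo root is on sys.path so utils.torch_utils is importable:

import torch
from utils.torch_utils import select_device

# '0' maps to cuda:0 when a GPU is visible; 'cpu' hides GPUs by setting
# CUDA_VISIBLE_DEVICES=-1 before CUDA is initialised.
device = select_device('0' if torch.cuda.is_available() else 'cpu')
print(device)  # device(type='cuda', index=0) or device(type='cpu')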