Add `--source screen` for screenshot inference (#9542)
* add screenshot as source * fix: screen number support * Fix: mutiple screen specific area * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * parse screen args in LoadScreenshots * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * sequence+ '_' as file name for save-txt save-crop * screenshot as stream * Update requirements.txt Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update dataloaders.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update dataloaders.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update detect.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update detect.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update detect.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update dataloaders.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update detect.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update detect.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update predict.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update detect.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update predict.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update README.md Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update tutorial.ipynb Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: xin <xin@zhiyoung.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9566/head
parent
b25d5a75f2
commit
30fa9b610a
|
@ -107,6 +107,7 @@ the latest YOLOv5 [release](https://github.com/ultralytics/yolov5/releases) and
|
|||
python detect.py --source 0 # webcam
|
||||
img.jpg # image
|
||||
vid.mp4 # video
|
||||
screen # screenshot
|
||||
path/ # directory
|
||||
'path/*.jpg' # glob
|
||||
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
||||
|
|
|
@ -42,7 +42,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
|||
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.augmentations import classify_transforms
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
|
||||
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||
increment_path, print_args, strip_optimizer)
|
||||
from utils.plots import Annotator
|
||||
|
@ -52,7 +52,7 @@ from utils.torch_utils import select_device, smart_inference_mode
|
|||
@smart_inference_mode()
|
||||
def run(
|
||||
weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s)
|
||||
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
|
||||
source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
|
||||
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
|
||||
imgsz=(224, 224), # inference size (height, width)
|
||||
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
||||
|
@ -74,6 +74,7 @@ def run(
|
|||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
|
||||
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
|
||||
screenshot = source.lower().startswith('screen')
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
|
||||
|
@ -91,6 +92,8 @@ def run(
|
|||
if webcam:
|
||||
view_img = check_imshow()
|
||||
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
|
||||
else:
|
||||
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
|
||||
bs = len(dataset) # batch_size
|
||||
|
@ -187,7 +190,7 @@ def run(
|
|||
def parse_opt():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)')
|
||||
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
|
||||
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
|
||||
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
|
||||
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[224], help='inference size h,w')
|
||||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
|
|
|
@ -40,7 +40,7 @@ if str(ROOT) not in sys.path:
|
|||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
|
||||
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
|
@ -50,7 +50,7 @@ from utils.torch_utils import select_device, smart_inference_mode
|
|||
@smart_inference_mode()
|
||||
def run(
|
||||
weights=ROOT / 'yolov5s.pt', # model.pt path(s)
|
||||
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
|
||||
source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
|
||||
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
|
||||
imgsz=(640, 640), # inference size (height, width)
|
||||
conf_thres=0.25, # confidence threshold
|
||||
|
@ -82,6 +82,7 @@ def run(
|
|||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
|
||||
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
|
||||
screenshot = source.lower().startswith('screen')
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
|
||||
|
@ -99,6 +100,8 @@ def run(
|
|||
if webcam:
|
||||
view_img = check_imshow()
|
||||
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
|
||||
else:
|
||||
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
|
||||
bs = len(dataset) # batch_size
|
||||
|
@ -212,7 +215,7 @@ def run(
|
|||
def parse_opt():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
|
||||
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
|
||||
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
|
||||
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
|
||||
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
|
||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
|
||||
|
|
|
@ -38,6 +38,7 @@ seaborn>=0.11.0
|
|||
ipython # interactive notebook
|
||||
psutil # system utilization
|
||||
thop>=0.1.1 # FLOPs computation
|
||||
# mss # screenshots
|
||||
# albumentations>=1.0.3
|
||||
# pycocotools>=2.0 # COCO mAP
|
||||
# roboflow
|
||||
|
|
|
@ -40,7 +40,7 @@ if str(ROOT) not in sys.path:
|
|||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
||||
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
|
||||
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
|
||||
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
|
||||
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
|
@ -51,7 +51,7 @@ from utils.torch_utils import select_device, smart_inference_mode
|
|||
@smart_inference_mode()
|
||||
def run(
|
||||
weights=ROOT / 'yolov5s-seg.pt', # model.pt path(s)
|
||||
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
|
||||
source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
|
||||
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
|
||||
imgsz=(640, 640), # inference size (height, width)
|
||||
conf_thres=0.25, # confidence threshold
|
||||
|
@ -84,6 +84,7 @@ def run(
|
|||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
|
||||
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
|
||||
screenshot = source.lower().startswith('screen')
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
|
||||
|
@ -101,6 +102,8 @@ def run(
|
|||
if webcam:
|
||||
view_img = check_imshow()
|
||||
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
|
||||
else:
|
||||
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
|
||||
bs = len(dataset) # batch_size
|
||||
|
@ -222,7 +225,7 @@ def run(
|
|||
def parse_opt():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-seg.pt', help='model path(s)')
|
||||
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
|
||||
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
|
||||
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
|
||||
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
|
||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
|
||||
|
|
|
@ -445,6 +445,7 @@
|
|||
"python detect.py --source 0 # webcam\n",
|
||||
" img.jpg # image \n",
|
||||
" vid.mp4 # video\n",
|
||||
" screen # screenshot\n",
|
||||
" path/ # directory\n",
|
||||
" 'path/*.jpg' # glob\n",
|
||||
" 'https://youtu.be/Zgi9g1ksQHc' # YouTube\n",
|
||||
|
|
|
@ -185,6 +185,55 @@ class _RepeatSampler:
|
|||
yield from iter(self.sampler)
|
||||
|
||||
|
||||
class LoadScreenshots:
|
||||
# YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
|
||||
def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
|
||||
# source = [screen_number left top width height] (pixels)
|
||||
check_requirements('mss')
|
||||
import mss
|
||||
|
||||
source, *params = source.split()
|
||||
self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
|
||||
if len(params) == 1:
|
||||
self.screen = int(params[0])
|
||||
elif len(params) == 4:
|
||||
left, top, width, height = (int(x) for x in params)
|
||||
elif len(params) == 5:
|
||||
self.screen, left, top, width, height = (int(x) for x in params)
|
||||
self.img_size = img_size
|
||||
self.stride = stride
|
||||
self.transforms = transforms
|
||||
self.auto = auto
|
||||
self.mode = 'stream'
|
||||
self.frame = 0
|
||||
self.sct = mss.mss()
|
||||
|
||||
# Parse monitor shape
|
||||
monitor = self.sct.monitors[self.screen]
|
||||
self.top = monitor["top"] if top is None else (monitor["top"] + top)
|
||||
self.left = monitor["left"] if left is None else (monitor["left"] + left)
|
||||
self.width = width or monitor["width"]
|
||||
self.height = height or monitor["height"]
|
||||
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
# mss screen capture: get raw pixels from the screen as np array
|
||||
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
||||
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
|
||||
|
||||
if self.transforms:
|
||||
im = self.transforms(im0) # transforms
|
||||
else:
|
||||
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
|
||||
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
self.frame += 1
|
||||
return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s
|
||||
|
||||
|
||||
class LoadImages:
|
||||
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
||||
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
|
|
Loading…
Reference in New Issue