Merge d748a1557b into fe1d4d9947
commit 83b1bbf814
.gitignore
@@ -37,6 +37,8 @@ data/images/*
 results*.csv
+output.txt
 
 # Datasets -------------------------------------------------------------------------------------------------------------
 coco/
 coco128/
detect.py
@@ -1,33 +1,3 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""
-Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc.
-
-Usage - sources:
-    $ python detect.py --weights yolov5s.pt --source 0                               # webcam
-                                                     img.jpg                         # image
-                                                     vid.mp4                         # video
-                                                     screen                          # screenshot
-                                                     path/                           # directory
-                                                     list.txt                        # list of images
-                                                     list.streams                    # list of streams
-                                                     'path/*.jpg'                    # glob
-                                                     'https://youtu.be/LNwODJXcvt4'  # YouTube
-                                                     'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
-
-Usage - formats:
-    $ python detect.py --weights yolov5s.pt                 # PyTorch
-                                 yolov5s.torchscript        # TorchScript
-                                 yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
-                                 yolov5s_openvino_model     # OpenVINO
-                                 yolov5s.engine             # TensorRT
-                                 yolov5s.mlpackage          # CoreML (macOS-only)
-                                 yolov5s_saved_model        # TensorFlow SavedModel
-                                 yolov5s.pb                 # TensorFlow GraphDef
-                                 yolov5s.tflite             # TensorFlow Lite
-                                 yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
-                                 yolov5s_paddle_model       # PaddlePaddle
-"""
-
 import argparse
 import csv
 import os
@@ -35,6 +5,7 @@ import platform
 import sys
 from pathlib import Path
 
+import paho.mqtt.client as mqtt
 import torch
 
 FILE = Path(__file__).resolve()
@@ -65,6 +36,84 @@ from utils.general import (
 )
 from utils.torch_utils import select_device, smart_inference_mode
 
+MQTT_BROKER = "broker.hivemq.com"
+MQTT_PORT = 1883
+MQTT_TOPIC = "Automation001"
+
+
+def on_connect(client, userdata, flags, rc):
+    if rc == 0:
+        print("Connected to MQTT broker successfully.")
+    else:
+        print(f"Failed to connect to MQTT broker, return code {rc}")
+
+
+def on_publish(client, userdata, mid):
+    print(f"Message {mid} published.")
+
+
+# Initialize MQTT client (paho-mqtt >= 2.0, as pinned in requirements.txt, requires an
+# explicit callback API version; VERSION1 matches the callback signatures used above)
+client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1)
+client.on_connect = on_connect
+client.on_publish = on_publish
+client.connect(MQTT_BROKER, MQTT_PORT, 60)  # Connect to the broker
+
+# Store detected objects over multiple frames (sliding window)
+from collections import deque
+
+WINDOW_SIZE = 30
+RECOGNITION_THRESHOLD = 0.7
+
+# Buffer to store objects detected over the last WINDOW_SIZE frames
+focus_buffer = deque(maxlen=WINDOW_SIZE)
+
+
+def get_focused_object(detections, img_shape):
+    """
+    Determines the object the user is focusing on by finding the detection closest to the center of the frame and
+    ensuring persistence over frames.
+
+    :param detections: Tensor of shape (N, 6) containing [x1, y1, x2, y2, confidence, class]
+    :param img_shape: Shape of the original image (height, width, channels)
+    :return: The most consistently detected class ID, or 80 (mapped to "none") if there is no stable focus.
+    """
+    if len(detections) == 0:
+        return 80  # No objects detected
+
+    # Step 1: Filter out low-confidence detections
+    detections = detections[detections[:, 4] > 0.5]  # Keep only confidence > 50%
+
+    if len(detections) == 0:
+        return 80  # No confident detections
+
+    # Step 2: Calculate object centroids (kept on the same device as the detections)
+    image_center = torch.tensor([img_shape[1] / 2, img_shape[0] / 2], device=detections.device)  # (x_center, y_center)
+    centroids = torch.stack(
+        [(detections[:, 0] + detections[:, 2]) / 2, (detections[:, 1] + detections[:, 3]) / 2], dim=1
+    )
+
+    # Step 3: Find the object closest to the center
+    distances = torch.norm(centroids - image_center, dim=1)  # Euclidean distance
+    min_distance_index = torch.argmin(distances)  # Index of the closest object
+
+    focused_object = int(detections[min_distance_index, 5])  # Get class ID of the focused object
+
+    # Step 4: Maintain a sliding window of detected objects
+    focus_buffer.append(focused_object)
+
+    # Step 5: Determine the most frequently appearing object in the buffer
+    focus_counts = {obj: focus_buffer.count(obj) for obj in set(focus_buffer)}
+    most_frequent_object = max(focus_counts, key=focus_counts.get)  # Object appearing most often in the buffer
+
+    # Only return it if it appears in at least 70% (RECOGNITION_THRESHOLD) of the frames in the buffer
+    if focus_counts[most_frequent_object] >= RECOGNITION_THRESHOLD * len(focus_buffer):
+        return most_frequent_object
+    else:
+        return 80  # No stable focus object
+
+
+# Class ID 80 corresponds to the synthetic "none" label registered below via names[80] = "none"
+
+
 @smart_inference_mode()
 def run(
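A quick way to sanity-check the sliding-window logic above is a throwaway smoke test; the tensor values below are fabricated for illustration, and note that importing this modified detect.py triggers the module-level MQTT connection as a side effect:

```python
import torch

from detect import WINDOW_SIZE, get_focused_object

# Two fabricated detections in a 640x480 frame: [x1, y1, x2, y2, confidence, class]
det = torch.tensor(
    [
        [100.0, 100.0, 200.0, 200.0, 0.90, 3.0],  # confident box nearest the frame center
        [0.0, 0.0, 40.0, 40.0, 0.40, 7.0],  # low confidence, dropped by the 0.5 cutoff
    ]
)

for _ in range(WINDOW_SIZE):  # feed the same frame until the 30-slot buffer is full
    focused = get_focused_object(det, (480, 640, 3))

print(focused)  # 3: class 3 fills 100% of the buffer, clearing the 70% threshold
```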
@@ -98,56 +147,6 @@ def run(
     dnn=False,  # use OpenCV DNN for ONNX inference
     vid_stride=1,  # video frame-rate stride
 ):
-    """
-    Runs YOLOv5 detection inference on various sources like images, videos, directories, streams, etc.
-
-    Args:
-        weights (str | Path): Path to the model weights file or a Triton URL. Default is 'yolov5s.pt'.
-        source (str | Path): Input source, which can be a file, directory, URL, glob pattern, screen capture, or webcam
-            index. Default is 'data/images'.
-        data (str | Path): Path to the dataset YAML file. Default is 'data/coco128.yaml'.
-        imgsz (tuple[int, int]): Inference image size as a tuple (height, width). Default is (640, 640).
-        conf_thres (float): Confidence threshold for detections. Default is 0.25.
-        iou_thres (float): Intersection Over Union (IOU) threshold for non-max suppression. Default is 0.45.
-        max_det (int): Maximum number of detections per image. Default is 1000.
-        device (str): CUDA device identifier (e.g., '0' or '0,1,2,3') or 'cpu'. Default is an empty string, which uses the
-            best available device.
-        view_img (bool): If True, display inference results using OpenCV. Default is False.
-        save_txt (bool): If True, save results in a text file. Default is False.
-        save_csv (bool): If True, save results in a CSV file. Default is False.
-        save_conf (bool): If True, include confidence scores in the saved results. Default is False.
-        save_crop (bool): If True, save cropped prediction boxes. Default is False.
-        nosave (bool): If True, do not save inference images or videos. Default is False.
-        classes (list[int]): List of class indices to filter detections by. Default is None.
-        agnostic_nms (bool): If True, perform class-agnostic non-max suppression. Default is False.
-        augment (bool): If True, use augmented inference. Default is False.
-        visualize (bool): If True, visualize feature maps. Default is False.
-        update (bool): If True, update all models' weights. Default is False.
-        project (str | Path): Directory to save results. Default is 'runs/detect'.
-        name (str): Name of the current experiment; used to create a subdirectory within 'project'. Default is 'exp'.
-        exist_ok (bool): If True, existing directories with the same name are reused instead of being incremented. Default is
-            False.
-        line_thickness (int): Thickness of bounding box lines in pixels. Default is 3.
-        hide_labels (bool): If True, do not display labels on bounding boxes. Default is False.
-        hide_conf (bool): If True, do not display confidence scores on bounding boxes. Default is False.
-        half (bool): If True, use FP16 half-precision inference. Default is False.
-        dnn (bool): If True, use OpenCV DNN backend for ONNX inference. Default is False.
-        vid_stride (int): Stride for processing video frames, to skip frames between processing. Default is 1.
-
-    Returns:
-        None
-
-    Examples:
-        ```python
-        from ultralytics import run
-
-        # Run inference on an image
-        run(source='data/images/example.jpg', weights='yolov5s.pt', device='0')
-
-        # Run inference on a video with specific confidence threshold
-        run(source='data/videos/example.mp4', weights='yolov5s.pt', conf_thres=0.4, device='0')
-        ```
-    """
     source = str(source)
     save_img = not nosave and not source.endswith(".txt")  # save inference images
     is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
@@ -165,6 +164,7 @@ def run(
     device = select_device(device)
     model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
     stride, names, pt = model.stride, model.names, model.pt
+    names[80] = "none"  # register an extra label for get_focused_object's no-stable-focus return value
     imgsz = check_img_size(imgsz, s=stride)  # check image size
 
     # Dataloader
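With the standard 80-class COCO weights assumed by this change, `model.names` maps class IDs 0-79, so ID 80 is free to carry the synthetic "none" label that `get_focused_object` returns when no stable focus exists.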
@@ -253,11 +253,16 @@ def run(
 
                 # Write results
                 for *xyxy, conf, cls in reversed(det):
-                    c = int(cls)  # integer class
+                    c = get_focused_object(det, im0.shape)  # class ID of the focused object (80 = "none")
                     label = names[c] if hide_conf else f"{names[c]}"
                     confidence = float(conf)
                     confidence_str = f"{confidence:.2f}"
 
+                    if confidence > 0.50:
+                        msg = f"{label}"
+                        # client.publish(MQTT_TOPIC, msg, retain=True)
+                        print(f"data sent to mqtt: {msg}")
+
                     if save_csv:
                         write_to_csv(p.name, label, confidence_str)
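The `client.publish` call is commented out in this revision, so nothing actually reaches the broker despite the "data sent to mqtt" log line. Once it is re-enabled, a minimal subscriber sketch such as the one below (reusing the broker and topic constants from the diff above) can confirm that focused-object labels arrive on Automation001:

```python
import paho.mqtt.client as mqtt


def on_message(client, userdata, msg):
    # Each payload is a focused-object label published by detect.py
    print(f"{msg.topic}: {msg.payload.decode()}")


sub = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1)
sub.on_message = on_message
sub.connect("broker.hivemq.com", 1883, 60)
sub.subscribe("Automation001")
sub.loop_forever()  # Block and print every message received on the topic
```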
@@ -311,6 +316,11 @@ def run(
         # Print time (inference-only)
         LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1e3:.1f}ms")
 
+        if cv2.waitKey(1) == ord("q"):
+            cv2.destroyAllWindows()
+            # raise StopIteration
+            break
+
     # Print results
     t = tuple(x.t / seen * 1e3 for x in dt)  # speeds per image
     LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)
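One caveat: `cv2.waitKey` only pumps key events for open HighGUI windows, so this `q`-to-quit path takes effect only when the script is run with `--view-img`; with no window open, the call simply returns -1 every frame.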
@@ -322,49 +332,6 @@ def run(
 
 
 def parse_opt():
-    """
-    Parse command-line arguments for YOLOv5 detection, allowing custom inference options and model configurations.
-
-    Args:
-        --weights (str | list[str], optional): Model path or Triton URL. Defaults to ROOT / 'yolov5s.pt'.
-        --source (str, optional): File/dir/URL/glob/screen/0(webcam). Defaults to ROOT / 'data/images'.
-        --data (str, optional): Dataset YAML path. Provides dataset configuration information.
-        --imgsz (list[int], optional): Inference size (height, width). Defaults to [640].
-        --conf-thres (float, optional): Confidence threshold. Defaults to 0.25.
-        --iou-thres (float, optional): NMS IoU threshold. Defaults to 0.45.
-        --max-det (int, optional): Maximum number of detections per image. Defaults to 1000.
-        --device (str, optional): CUDA device, i.e., '0' or '0,1,2,3' or 'cpu'. Defaults to "".
-        --view-img (bool, optional): Flag to display results. Defaults to False.
-        --save-txt (bool, optional): Flag to save results to *.txt files. Defaults to False.
-        --save-csv (bool, optional): Flag to save results in CSV format. Defaults to False.
-        --save-conf (bool, optional): Flag to save confidences in labels saved via --save-txt. Defaults to False.
-        --save-crop (bool, optional): Flag to save cropped prediction boxes. Defaults to False.
-        --nosave (bool, optional): Flag to prevent saving images/videos. Defaults to False.
-        --classes (list[int], optional): List of classes to filter results by, e.g., '--classes 0 2 3'. Defaults to None.
-        --agnostic-nms (bool, optional): Flag for class-agnostic NMS. Defaults to False.
-        --augment (bool, optional): Flag for augmented inference. Defaults to False.
-        --visualize (bool, optional): Flag for visualizing features. Defaults to False.
-        --update (bool, optional): Flag to update all models in the model directory. Defaults to False.
-        --project (str, optional): Directory to save results. Defaults to ROOT / 'runs/detect'.
-        --name (str, optional): Sub-directory name for saving results within --project. Defaults to 'exp'.
-        --exist-ok (bool, optional): Flag to allow overwriting if the project/name already exists. Defaults to False.
-        --line-thickness (int, optional): Thickness (in pixels) of bounding boxes. Defaults to 3.
-        --hide-labels (bool, optional): Flag to hide labels in the output. Defaults to False.
-        --hide-conf (bool, optional): Flag to hide confidences in the output. Defaults to False.
-        --half (bool, optional): Flag to use FP16 half-precision inference. Defaults to False.
-        --dnn (bool, optional): Flag to use OpenCV DNN for ONNX inference. Defaults to False.
-        --vid-stride (int, optional): Video frame-rate stride, determining the number of frames to skip in between
-            consecutive frames. Defaults to 1.
-
-    Returns:
-        argparse.Namespace: Parsed command-line arguments as an argparse.Namespace object.
-
-    Example:
-        ```python
-        from ultralytics import YOLOv5
-        args = YOLOv5.parse_opt()
-        ```
-    """
     parser = argparse.ArgumentParser()
     parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL")
     parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)")
@@ -407,28 +374,6 @@ def parse_opt():
 
 
 def main(opt):
-    """
-    Executes YOLOv5 model inference based on provided command-line arguments, validating dependencies before running.
-
-    Args:
-        opt (argparse.Namespace): Command-line arguments for YOLOv5 detection. See function `parse_opt` for details.
-
-    Returns:
-        None
-
-    Note:
-        This function performs essential pre-execution checks and initiates the YOLOv5 detection process based on
-        user-specified options. Refer to the usage guide and examples for more information about different sources and
-        formats at: https://github.com/ultralytics/ultralytics
-
-    Example usage:
-        ```python
-        if __name__ == "__main__":
-            opt = parse_opt()
-            main(opt)
-        ```
-    """
     check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop"))
     run(**vars(opt))
pyproject.toml
@@ -145,3 +145,6 @@ close-quotes-on-newline = true
 [tool.codespell]
 ignore-words-list = "crate,nd,strack,dota,ane,segway,fo,gool,winn,commend"
 skip = '*.csv,*venv*,docs/??/,docs/mkdocs_??.yml'
+
+[tool.setuptools.packages]
+find = { include = ["models", "data", "segment", "classify"] }
requirements.txt
@@ -47,3 +47,4 @@ setuptools>=70.0.0  # Snyk vulnerability fix
 # mss  # screenshots
 # albumentations>=1.0.3
 # pycocotools>=2.0.6  # COCO mAP
+paho-mqtt>=2.1.0
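The paho-mqtt 2.x pin matters here: since version 2.0, `mqtt.Client()` with no arguments raises an error, which is why the client construction in detect.py passes `mqtt.CallbackAPIVersion.VERSION1` to keep the 1.x-style `on_connect`/`on_publish` callback signatures used in this change.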