detect.py full changes
parent ed41f4e212
commit 0f4fb1f426
@@ -1,34 +1,26 @@
import argparse  # to pass arguments on the command line
import csv  # read and write csv
import os  # os path, folder and file operations
import argparse
import csv
import os
import platform
import sys  # for changing system configurations
import paho.mqtt.client as mqtt  # for the MQTT connection protocol
import sys
import paho.mqtt.client as mqtt
import json
import time
from pathlib import Path
from collections import Counter

import torch

file_path = "output.txt"

FILE = Path(__file__).resolve()  # __file__ passed as command line argument
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from ultralytics.utils.plotting import Annotator, colors, save_one_box
# Annotator - puts bounding boxes on the image

from models.common import DetectMultiBackend
# to be compatible with Raspberry Pi

from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
# to support different input data types

from utils.general import (
    LOGGER,
    Profile,
@@ -45,13 +37,10 @@ from utils.general import (
    strip_optimizer,
    xyxy2xywh,
)
# for general operations like logging, checking image size, opening windows, and so on

from utils.torch_utils import select_device, smart_inference_mode
# CPU or GPU selection; to tune model parameters

MQTT_BROKER = "broker.hivemq.com"
MQTT_PORT = 1883  # to pass unencrypted messages
MQTT_PORT = 1883
MQTT_TOPIC = "Automation001"

def on_connect(client, userdata, flags, rc):
@@ -69,15 +58,64 @@ client.on_connect = on_connect
client.on_publish = on_publish
client.connect(MQTT_BROKER, MQTT_PORT, 60)  # Connect to the broker
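For reference, the labels published here can be checked with a minimal subscriber on the same broker and topic. This is a sketch only, assuming the paho-mqtt 1.x callback API; it is not part of detect.py:

import paho.mqtt.client as mqtt

def on_message(client, userdata, msg):
    print(f"received on {msg.topic}: {msg.payload.decode()}")  # label text published by detect.py

sub = mqtt.Client()
sub.on_message = on_message
sub.connect("broker.hivemq.com", 1883, 60)  # same public broker and port as above
sub.subscribe("Automation001")              # same topic as MQTT_TOPIC
sub.loop_forever()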

# constants for logic
WINDOW_SIZE = 22
RECOGNITION_THRESHOLD = 19  # minimum detections required within the window
# Store detected objects over multiple frames (sliding window)
from collections import deque

WINDOW_SIZE = 30
RECOGNITION_THRESHOLD = 0.7

# Buffer to store objects detected over the last N frames
focus_buffer = deque(maxlen=WINDOW_SIZE)  # Track objects over the last WINDOW_SIZE frames
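The deque acts as a fixed-length sliding window: once WINDOW_SIZE entries are stored, each new append silently drops the oldest one. A minimal sketch of that behaviour (not part of detect.py):

from collections import deque

buf = deque(maxlen=3)
for class_id in [0, 0, 56, 56]:
    buf.append(class_id)
print(list(buf))  # [0, 56, 56] - the earliest 0 has been pushed out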

def get_focused_object(detections, img_shape):
    """
    Determines the object the user is focusing on by finding the detection
    closest to the center of the frame and ensuring persistence over frames.

    :param detections: Tensor of shape (N, 6) containing [x1, y1, x2, y2, confidence, class]
    :param img_shape: Shape of the original image (height, width, channels)
    :return: The most consistently detected object (class ID), or 80 ('none') if there is no stable focus.
    """
    if len(detections) == 0:
        return 80  # No objects detected

    # Step 1: Filter out low-confidence detections
    detections = detections[detections[:, 4] > 0.5]  # Keep only confidence > 50%

    if len(detections) == 0:
        return 80  # No confident detections

    # Step 2: Calculate object centroids
    image_center = torch.tensor([img_shape[1] / 2, img_shape[0] / 2])  # (x_center, y_center)
    centroids = torch.stack([(detections[:, 0] + detections[:, 2]) / 2,
                             (detections[:, 1] + detections[:, 3]) / 2], dim=1)

    # Step 3: Find the object closest to the center
    distances = torch.norm(centroids - image_center, dim=1)  # Euclidean distance
    min_distance_index = torch.argmin(distances)  # Index of the closest object

    focused_object = int(detections[min_distance_index, 5])  # Get class ID of the focused object

    # Step 4: Maintain a sliding window of detected objects
    focus_buffer.append(focused_object)

    # Step 5: Determine the most frequently appearing object in the buffer
    focus_counts = {obj: focus_buffer.count(obj) for obj in set(focus_buffer)}
    most_frequent_object = max(focus_counts, key=focus_counts.get)  # Object appearing most often in the buffer

    # Only return it if it appears in at least RECOGNITION_THRESHOLD (70%) of the frames in the buffer
    if focus_counts[most_frequent_object] >= RECOGNITION_THRESHOLD * len(focus_buffer):
        return most_frequent_object
    else:
        return 80  # No stable focus object

# Class index 80 corresponds to the extra 'none' label (names[80] = 'none' is set below)
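A minimal sketch of calling get_focused_object directly; the detection tensor is invented for illustration and follows the [x1, y1, x2, y2, confidence, class] layout described in the docstring:

# Two confident detections on a 480x640 frame; class 56 sits exactly at the centre (320, 240).
dummy_det = torch.tensor([
    [ 10.0,  10.0, 100.0, 100.0, 0.90,  0.0],   # class 0, top-left corner
    [300.0, 220.0, 340.0, 260.0, 0.85, 56.0],   # class 56, centred
])
print(get_focused_object(dummy_det, (480, 640, 3)))  # 56 - closest to the centre and above the buffer threshold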

@smart_inference_mode()
def run(
    weights=ROOT / "yolov5s.pt",  # model path or triton URL  # s because small
    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)  # path to input images
    data=ROOT / "data/coco128.yaml",  # dataset.yaml path  # model configurations, check internet
    weights=ROOT / "yolov5s.pt",  # model path or triton URL
    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)
    data=ROOT / "data/coco128.yaml",  # dataset.yaml path
    imgsz=(640, 640),  # inference size (height, width)
    conf_thres=0.25,  # confidence threshold
    iou_thres=0.45,  # NMS IOU threshold
@@ -95,7 +133,7 @@ def run(
    augment=False,  # augmented inference
    visualize=False,  # visualize features
    update=False,  # update all models
    project=ROOT / "runs/detect",  # save results to project/name  # path to this file
    project=ROOT / "runs/detect",  # save results to project/name
    name="exp",  # save results to project/name
    exist_ok=False,  # existing project/name ok, do not increment
    line_thickness=3,  # bounding box thickness (pixels)
@@ -106,10 +144,6 @@ def run(
    vid_stride=1,  # video frame-rate stride
):

    labels = []
    best_label = ""
    best_confidence = 0

    source = str(source)
    save_img = not nosave and not source.endswith(".txt")  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
@@ -127,6 +161,7 @@ def run(
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    names[80] = 'none'  # extra class index used as the "no stable focus" sentinel returned by get_focused_object()
    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Dataloader
@@ -135,7 +170,7 @@ def run(
        view_img = check_imshow(warn=True)
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        bs = len(dataset)
    elif screenshot:  # screenshot of computer window
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
@@ -145,7 +180,6 @@ def run(
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
    for path, im, im0s, vid_cap, s in dataset:

        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
@@ -191,6 +225,7 @@ def run(
        # Process predictions
        for i, det in enumerate(pred):  # per image

            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
@@ -216,38 +251,15 @@ def run(
                # Write results
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)  # integer class
                    c = get_focused_object(det, im0.shape)  # class ID of the focused object (80 = 'none')
                    label = names[c] if hide_conf else f"{names[c]}"
                    confidence = float(conf)
                    confidence_str = f"{confidence:.2f}"
                    labels.append((confidence, label))
                    if seen % WINDOW_SIZE == 0:
                        if labels:
                            # Sort labels in descending order of confidence (higher confidence first)
                            labels.sort(reverse=True, key=lambda x: x[0])

                            # Count occurrences of each label
                            label_counts = Counter(label for _, label in labels)

                            # Find the most frequent label(s)
                            max_frequency = max(label_counts.values())
                            most_frequent_labels = {label for label, count in label_counts.items() if count == max_frequency}

                            # Select the highest confidence instance among the most frequent labels
                            for conf, lbl in labels:
                                if lbl in most_frequent_labels:
                                    best_label = lbl
                                    best_confidence = conf
                                    break  # Pick the first match (highest confidence among most frequent)

                            # Ensure it meets the recognition threshold and confidence requirement
                            if label_counts[best_label] >= RECOGNITION_THRESHOLD and best_confidence > 0.50:
                                msg = f"{best_label}"
                                client.publish(MQTT_TOPIC, msg, retain=True)
                                print("Data sent to MQTT:", msg)
                                with open(file_path, "a") as f:  # Append mode
                                    f.write(f"{best_label} with confidence: {best_confidence}\n")
                            labels = []

                    if confidence > 0.50:
                        msg = f"{label}"
                        # client.publish(MQTT_TOPIC, msg, retain=True)
                        print(f"data sent to mqtt: {msg}")

                    if save_csv:
                        write_to_csv(p.name, label, confidence_str)
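The windowed voting above (most frequent label, ties broken by highest confidence) can be isolated into a small helper for testing. A sketch with invented sample data, not part of the commit:

from collections import Counter

def vote(labels):
    labels = sorted(labels, reverse=True, key=lambda x: x[0])  # highest confidence first
    counts = Counter(lbl for _, lbl in labels)
    top = max(counts.values())
    winners = {lbl for lbl, n in counts.items() if n == top}
    for conf, lbl in labels:  # first match is the best-confidence instance of a winning label
        if lbl in winners:
            return lbl, conf

print(vote([(0.62, "cup"), (0.91, "laptop"), (0.58, "cup")]))  # ('cup', 0.62)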
@@ -298,10 +310,16 @@ def run(
                            save_path = str(Path(save_path).with_suffix(".mp4"))  # force *.mp4 suffix on results videos
                            vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
                        vid_writer[i].write(im0)

            # Print time (inference-only)
            LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1e3:.1f}ms")

        if cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            # raise StopIteration
            break

    # Print results
    t = tuple(x.t / seen * 1e3 for x in dt)  # speeds per image
    LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)