From 0f4fb1f4262449e723893abb6d834ef661e533a9 Mon Sep 17 00:00:00 2001
From: arjunramdas2
Date: Tue, 1 Apr 2025 19:02:09 +0530
Subject: [PATCH] detect.py full changes

---
 detect.py | 138 ++++++++++++++++++++++++++++++------------------------
 1 file changed, 78 insertions(+), 60 deletions(-)

diff --git a/detect.py b/detect.py
index 0309a158a..687a71b32 100644
--- a/detect.py
+++ b/detect.py
@@ -1,34 +1,26 @@
-import argparse # to pass arguments on comment line
-import csv # read and write csv
-import os # os path, folder and file operations
+import argparse
+import csv
+import os
 import platform
-import sys # for changing system configurations
-import paho.mqtt.client as mqtt # for mqtt connection protocol
+import sys
+import paho.mqtt.client as mqtt
 import json
 import time
 from pathlib import Path
-from collections import Counter
 
 import torch
 
-file_path = "output.txt"
-
-FILE = Path(__file__).resolve() # __file__ passed as command line argument
+FILE = Path(__file__).resolve()
 ROOT = FILE.parents[0]  # YOLOv5 root directory
 if str(ROOT) not in sys.path:
     sys.path.append(str(ROOT))  # add ROOT to PATH
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 
 from ultralytics.utils.plotting import Annotator, colors, save_one_box
-# Annotator - put bounding box
 from models.common import DetectMultiBackend
-# to be compatible with raspberry pi
-
 from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
-# to support different input data types
-
 from utils.general import (
     LOGGER,
     Profile,
@@ -45,13 +37,10 @@ from utils.general import (
     strip_optimizer,
     xyxy2xywh,
 )
-# for general operations like logging, check image size, open window, and so on
-
 from utils.torch_utils import select_device, smart_inference_mode
-# cpu or gpu selection; to tune model parameters
 
 MQTT_BROKER = "broker.hivemq.com"
-MQTT_PORT = 1883 # to pass unencrypted messages
+MQTT_PORT = 1883
 MQTT_TOPIC = "Automation001"
 
 def on_connect(client, userdata, flags, rc):
@@ -69,15 +58,64 @@ client.on_connect = on_connect
 client.on_publish = on_publish
 client.connect(MQTT_BROKER, MQTT_PORT, 60)  # Connect to the broker
 
-# constants for logic
-WINDOW_SIZE = 22
-RECOGNITION_THRESHOLD = 19 # 30 percent of window size
+# Store detected objects over multiple frames (Sliding Window)
+from collections import deque
+
+WINDOW_SIZE = 30
+RECOGNITION_THRESHOLD = 0.7  # fraction of buffered frames a class must fill to count as a stable focus
+
+# Buffer to store objects detected over the last N frames
+focus_buffer = deque(maxlen=WINDOW_SIZE)  # Track objects over the last WINDOW_SIZE frames
+
+def get_focused_object(detections, img_shape):
+    """
+    Determines the object the user is focusing on by finding the detection
+    closest to the center of the frame and ensuring persistence over frames.
+
+    :param detections: Tensor of shape (N, 6) containing [x1, y1, x2, y2, confidence, class]
+    :param img_shape: Shape of the original image (height, width, channels)
+    :return: The most consistently detected class ID, or 80 (the extra 'none' class) if there is no stable focus.
+    """
+    if len(detections) == 0:
+        return 80  # No objects detected
+
+    # Step 1: Filter out low-confidence detections
+    detections = detections[detections[:, 4] > 0.5]  # Keep only confidence > 50%
+
+    if len(detections) == 0:
+        return 80  # No confident detections
+
+    # Step 2: Calculate object centroids
+    image_center = torch.tensor([img_shape[1] / 2, img_shape[0] / 2], device=detections.device)  # (x_center, y_center)
+    centroids = torch.stack([(detections[:, 0] + detections[:, 2]) / 2,
+                             (detections[:, 1] + detections[:, 3]) / 2], dim=1)
+
+    # Step 3: Find the object closest to the center
+    distances = torch.norm(centroids - image_center, dim=1)  # Euclidean distance
+    min_distance_index = torch.argmin(distances)  # Index of the closest object
+
+    focused_object = int(detections[min_distance_index, 5])  # Get class ID of the focused object
+
+    # Step 4: Maintain a sliding window of detected objects
+    focus_buffer.append(focused_object)
+
+    # Step 5: Determine the most frequently appearing object in the buffer
+    focus_counts = {obj: focus_buffer.count(obj) for obj in set(focus_buffer)}
+    most_frequent_object = max(focus_counts, key=focus_counts.get)  # Object appearing most in buffer
+
+    # Only return it if it fills at least RECOGNITION_THRESHOLD (70%) of the frames in the buffer
+    if focus_counts[most_frequent_object] >= RECOGNITION_THRESHOLD * len(focus_buffer):
+        return most_frequent_object
+    else:
+        return 80  # No stable focus object
+
+# Class index 80 is the extra 'none' label registered below via names[80] = 'none'
 
 @smart_inference_mode()
 def run(
-    weights=ROOT / "yolov5s.pt",  # model path or triton URL # s because small
-    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam) # path to input images
-    data=ROOT / "data/coco128.yaml",  # dataset.yaml path # model configurations, check internet
+    weights=ROOT / "yolov5s.pt",  # model path or triton URL
+    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)
+    data=ROOT / "data/coco128.yaml",  # dataset.yaml path
     imgsz=(640, 640),  # inference size (height, width)
     conf_thres=0.25,  # confidence threshold
     iou_thres=0.45,  # NMS IOU threshold
@@ -95,7 +133,7 @@ def run(
     augment=False,  # augmented inference
     visualize=False,  # visualize features
     update=False,  # update all models
-    project=ROOT / "runs/detect",  # save results to project/name # path to this file
+    project=ROOT / "runs/detect",  # save results to project/name
     name="exp",  # save results to project/name
     exist_ok=False,  # existing project/name ok, do not increment
     line_thickness=3,  # bounding box thickness (pixels)
@@ -106,10 +144,6 @@ def run(
     vid_stride=1,  # video frame-rate stride
 ):
-    labels = []
-    best_label = ""
-    best_confidence = 0
-
     source = str(source)
     save_img = not nosave and not source.endswith(".txt")  # save inference images
     is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
@@ -127,6 +161,7 @@ def run(
     device = select_device(device)
     model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
     stride, names, pt = model.stride, model.names, model.pt
+    names[80] = 'none'  # extra class returned by get_focused_object() when no stable focus is found
     imgsz = check_img_size(imgsz, s=stride)  # check image size
 
     # Dataloader
@@ -135,7 +170,7 @@ def run(
         view_img = check_imshow(warn=True)
         dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
         bs = len(dataset)
-    elif screenshot: # screenshot of computer window
+    elif screenshot:
         dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
     else:
         dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
@@ -145,7 +180,6 @@ def run(
     model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
     seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
     for path, im, im0s, vid_cap, s in dataset:
-
         with dt[0]:
             im = torch.from_numpy(im).to(model.device)
             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
@@ -191,6 +225,7 @@ def run(
 
         # Process predictions
         for i, det in enumerate(pred):  # per image
+            seen += 1
             if webcam:  # batch_size >= 1
                 p, im0, frame = path[i], im0s[i].copy(), dataset.count
@@ -216,38 +251,15 @@ def run(
 
                 # Write results
                 for *xyxy, conf, cls in reversed(det):
-                    c = int(cls)  # integer class
+                    c = get_focused_object(det, im0.shape)  # class ID of the focused object (80 = 'none')
                     label = names[c] if hide_conf else f"{names[c]}"
                     confidence = float(conf)
                     confidence_str = f"{confidence:.2f}"
-                    labels.append((confidence, label))
-                    if seen % WINDOW_SIZE == 0:
-                        if labels:
-                            # Sort labels in descending order of confidence (higher confidence first)
-                            labels.sort(reverse=True, key=lambda x: x[0])
-
-                            # Count occurrences of each label
-                            label_counts = Counter(label for _, label in labels)
-
-                            # Find the most frequent label(s)
-                            max_frequency = max(label_counts.values())
-                            most_frequent_labels = {label for label, count in label_counts.items() if count == max_frequency}
-
-                            # Select the highest confidence instance among the most frequent labels
-                            for conf, lbl in labels:
-                                if lbl in most_frequent_labels:
-                                    best_label = lbl
-                                    best_confidence = conf
-                                    break  # Pick the first match (highest confidence among most frequent)
-
-                            # Ensure it meets the recognition threshold and confidence requirement
-                            if label_counts[best_label] >= RECOGNITION_THRESHOLD and best_confidence > 0.50:
-                                msg = f"{best_label}"
-                                client.publish(MQTT_TOPIC, msg, retain=True)
-                                print("Data sent to MQTT:", msg)
-                                with open(file_path, "a") as f:  # Append mode
-                                    f.write(f"{best_label} with confidence: {best_confidence}\n")
-                            labels = []
+
+                    if confidence > 0.50:
+                        msg = f"{label}"
+                        # client.publish(MQTT_TOPIC, msg, retain=True)
+                        print(f"data sent to mqtt: {msg}")
 
                     if save_csv:
                         write_to_csv(p.name, label, confidence_str)
@@ -298,10 +310,16 @@ def run(
                             save_path = str(Path(save_path).with_suffix(".mp4"))  # force *.mp4 suffix on results videos
                         vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
                     vid_writer[i].write(im0)
+
             # Print time (inference-only)
             LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1e3:.1f}ms")
 
+        if cv2.waitKey(1) == ord('q'):
+            cv2.destroyAllWindows()
+            # raise StopIteration
+            break
+
     # Print results
     t = tuple(x.t / seen * 1e3 for x in dt)  # speeds per image
     LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)
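
For reviewers who want to check the MQTT side end to end, below is a minimal listener sketch. It is not part of the patch; it assumes the broker.hivemq.com / 1883 / Automation001 values hard-coded above and the same paho-mqtt 1.x callback style detect.py already uses, and it only becomes useful once the commented-out client.publish call in the write-results loop is re-enabled.

import paho.mqtt.client as mqtt

MQTT_BROKER = "broker.hivemq.com"  # same public broker as detect.py
MQTT_PORT = 1883                   # unencrypted MQTT port
MQTT_TOPIC = "Automation001"       # topic detect.py publishes labels to

def on_connect(client, userdata, flags, rc):
    # Subscribe from the callback so reconnects also re-subscribe
    print(f"Connected with result code {rc}")
    client.subscribe(MQTT_TOPIC)

def on_message(client, userdata, msg):
    # Payload is the label string detect.py publishes on this topic
    print(f"{msg.topic}: {msg.payload.decode()}")

client = mqtt.Client()
client.on_connect = on_connect
client.on_message = on_message
client.connect(MQTT_BROKER, MQTT_PORT, 60)
client.loop_forever()  # blocks; stop with Ctrl+C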