detect.py full changes

pull/13556/head
arjunramdas2 2025-04-01 19:02:09 +05:30
parent ed41f4e212
commit 0f4fb1f426
1 changed file with 78 additions and 60 deletions

detect.py

@@ -1,34 +1,26 @@
import argparse # to pass arguments on the command line
import csv # read and write csv
import os # os path, folder and file operations
import argparse
import csv
import os
import platform
import sys # for changing system configurations
import paho.mqtt.client as mqtt # for mqtt connection protocol
import sys
import paho.mqtt.client as mqtt
import json
import time
from pathlib import Path
from collections import Counter
import torch
file_path = "output.txt"
FILE = Path(__file__).resolve() # absolute path of this script
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from ultralytics.utils.plotting import Annotator, colors, save_one_box
# Annotator - put bounding box
from models.common import DetectMultiBackend
# to be compatible with raspberry pi
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
# to support different input data types
from utils.general import (
LOGGER,
Profile,
@@ -45,13 +37,10 @@ from utils.general import (
strip_optimizer,
xyxy2xywh,
)
# for general operations like logging, checking image size, opening windows, and so on
from utils.torch_utils import select_device, smart_inference_mode
# cpu or gpu selection; to tune model parameters
MQTT_BROKER = "broker.hivemq.com"
MQTT_PORT = 1883 # to pass unencrypted messages
MQTT_PORT = 1883
MQTT_TOPIC = "Automation001"
def on_connect(client, userdata, flags, rc):
@@ -69,15 +58,64 @@ client.on_connect = on_connect
client.on_publish = on_publish
client.connect(MQTT_BROKER, MQTT_PORT, 60) # Connect to the broker
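The bodies of the two callbacks are collapsed out of this diff. For reference, a minimal sketch of what they might look like with the paho-mqtt v1 callback API (illustrative only, not the PR's code):

def on_connect(client, userdata, flags, rc):
    # rc == 0 means the broker accepted the connection
    print(f"Connected to MQTT broker with result code {rc}")

def on_publish(client, userdata, mid):
    # mid is the ID paho assigns to the outgoing message
    print(f"Published message {mid}")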
# constants for logic
WINDOW_SIZE = 22
RECOGNITION_THRESHOLD = 19 # about 86 percent of WINDOW_SIZE
# Store detected objects over multiple frames (Sliding Window)
from collections import deque
WINDOW_SIZE = 30
RECOGNITION_THRESHOLD = 0.7
# Buffer to store objects detected over last N frames
focus_buffer = deque(maxlen=WINDOW_SIZE) # Track objects over the last WINDOW_SIZE frames
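Concretely: once the buffer is full, an object must be the closest-to-center detection in at least 0.7 × 30 = 21 of the last 30 frames before it is reported as the focus; until the buffer fills, the required count scales with the current buffer length.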
def get_focused_object(detections, img_shape):
"""
Determines the object the user is focusing on by finding the detection
closest to the center of the frame and ensuring persistence over frames.
:param detections: Tensor of shape (N, 6) containing [x1, y1, x2, y2, confidence, class]
:param img_shape: Shape of the original image (height, width, channels)
:return: The class ID of the most consistently detected object, or 80 (the 'none' class) if there is no stable focus.
"""
if len(detections) == 0:
return 80 # No objects detected
# Step 1: Filter out low-confidence detections
detections = detections[detections[:, 4] > 0.5] # Keep only confidence > 50%
if len(detections) == 0:
return 80 # No confident detections
# Step 2: Calculate object centroids
image_center = torch.tensor([img_shape[1] / 2, img_shape[0] / 2], device=detections.device) # (x_center, y_center), on the same device as detections to avoid a CPU/GPU mismatch
centroids = torch.stack([(detections[:, 0] + detections[:, 2]) / 2,
(detections[:, 1] + detections[:, 3]) / 2], dim=1)
# Step 3: Find the object closest to the center
distances = torch.norm(centroids - image_center, dim=1) # Euclidean distance
min_distance_index = torch.argmin(distances) # Index of the closest object
focused_object = int(detections[min_distance_index, 5]) # Get class ID of the focused object
# Step 4: Maintain a sliding window of detected objects
focus_buffer.append(focused_object)
# Step 5: Determine the most frequently appearing object in buffer
focus_counts = Counter(focus_buffer) # frequency of each class ID in the window
most_frequent_object = focus_counts.most_common(1)[0][0] # object appearing most in the buffer
# Only return the class if it appears in at least 70% (RECOGNITION_THRESHOLD) of the frames currently in the buffer
if focus_counts[most_frequent_object] >= RECOGNITION_THRESHOLD * len(focus_buffer):
return most_frequent_object
else:
return 80 # No stable focus object
# Class ID 80 corresponds to the extra 'none' label registered below via names[80] = 'none'
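A quick sketch of how the new helper behaves, using a synthetic detections tensor (the frame size and box coordinates below are made up for illustration and are not part of the PR):

# Two detections in a 480x640 frame: class 0 dead-center, class 1 near a corner
dets = torch.tensor([
    [300.0, 220.0, 340.0, 260.0, 0.90, 0.0],  # centroid (320, 240) == image center
    [10.0, 10.0, 60.0, 60.0, 0.85, 1.0],      # centroid (35, 35), far from center
])
for _ in range(WINDOW_SIZE):  # fill the sliding window with the same focus
    focused = get_focused_object(dets, (480, 640, 3))
print(focused)  # prints 0: class 0 dominates the buffer and passes the 70% threshold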
@smart_inference_mode()
def run(
weights=ROOT / "yolov5s.pt", # model path or triton URL # s because small
source=ROOT / "data/images", # file/dir/URL/glob/screen/0(webcam) # path to input images
data=ROOT / "data/coco128.yaml", # dataset.yaml path # model configurations, check internet
weights=ROOT / "yolov5s.pt", # model path or triton URL
source=ROOT / "data/images", # file/dir/URL/glob/screen/0(webcam)
data=ROOT / "data/coco128.yaml", # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
conf_thres=0.25, # confidence threshold
iou_thres=0.45, # NMS IOU threshold
@@ -95,7 +133,7 @@ def run(
augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models
project=ROOT / "runs/detect", # save results to project/name # path to this file
project=ROOT / "runs/detect", # save results to project/name
name="exp", # save results to project/name
exist_ok=False, # existing project/name ok, do not increment
line_thickness=3, # bounding box thickness (pixels)
@@ -106,10 +144,6 @@ def run(
vid_stride=1, # video frame-rate stride
):
labels = []
best_label = ""
best_confidence = 0
source = str(source)
save_img = not nosave and not source.endswith(".txt") # save inference images
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
@@ -127,6 +161,7 @@ def run(
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
names[80] = 'none' # register the extra class ID that get_focused_object returns when no stable focus exists
imgsz = check_img_size(imgsz, s=stride) # check image size
# Dataloader
@@ -135,7 +170,7 @@
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot: # screenshot of computer window
elif screenshot:
dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
@@ -145,7 +180,6 @@
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
for path, im, im0s, vid_cap, s in dataset:
with dt[0]:
im = torch.from_numpy(im).to(model.device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
@@ -191,6 +225,7 @@
# Process predictions
for i, det in enumerate(pred): # per image
seen += 1
if webcam: # batch_size >= 1
p, im0, frame = path[i], im0s[i].copy(), dataset.count
@@ -216,38 +251,15 @@
# Write results
for *xyxy, conf, cls in reversed(det):
c = int(cls) # integer class
c = get_focused_object(det, im0.shape) # class ID of the focused object (80 = 'none')
label = names[c] if hide_conf else f"{names[c]}"
confidence = float(conf)
confidence_str = f"{confidence:.2f}"
labels.append((confidence, label))
if seen % WINDOW_SIZE == 0:
if labels:
# Sort labels in descending order of confidence (higher confidence first)
labels.sort(reverse=True, key=lambda x: x[0])
# Count occurrences of each label
label_counts = Counter(label for _, label in labels)
# Find the most frequent label(s)
max_frequency = max(label_counts.values())
most_frequent_labels = {label for label, count in label_counts.items() if count == max_frequency}
# Select the highest confidence instance among the most frequent labels
for conf, lbl in labels:
if lbl in most_frequent_labels:
best_label = lbl
best_confidence = conf
break # Pick the first match (highest confidence among most frequent)
# Ensure it meets the recognition threshold and confidence requirement
if label_counts[best_label] >= RECOGNITION_THRESHOLD and best_confidence > 0.50:
msg = f"{best_label}"
client.publish(MQTT_TOPIC, msg, retain=True)
print("Data sent to MQTT:", msg)
with open(file_path, "a") as f: # Append mode
f.write(f"{best_label} with confidence: {best_confidence}\n")
labels = []
if confidence > 0.50:
msg = f"{label}"
# client.publish(MQTT_TOPIC, msg, retain=True) # publishing currently commented out
print(f"Data sent to MQTT: {msg}")
if save_csv:
write_to_csv(p.name, label, confidence_str)
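write_to_csv is defined further up in detect.py and collapsed out of this diff. A plausible sketch of such a helper, assuming a csv_path variable set earlier in run() (the name and body here are illustrative assumptions, not the PR's code):

def write_to_csv(image_name, prediction, confidence):
    # Append one row per detection, writing the header only when the file is first created
    data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence}
    file_exists = os.path.isfile(csv_path)
    with open(csv_path, mode="a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=data.keys())
        if not file_exists:
            writer.writeheader()
        writer.writerow(data)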
@@ -298,10 +310,16 @@ def run(
save_path = str(Path(save_path).with_suffix(".mp4")) # force *.mp4 suffix on results videos
vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
vid_writer[i].write(im0)
# Print time (inference-only)
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1e3:.1f}ms")
if cv2.waitKey(1) == ord('q'):
cv2.destroyAllWindows() # parentheses required; without them this line is a no-op
# raise StopIteration
break
# Print results
t = tuple(x.t / seen * 1e3 for x in dt) # speeds per image
LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)