使用 YOLOv8 进行物体检测和跟踪

问题描述 投票:0回答:1

大家好，我有下面这段目标检测（object detection）代码：

import os
import tempfile

import cv2
import streamlit as st
from PIL import Image
from ultralytics import YOLO


def _display_detected_frames(conf, model, st_frame, image):
    """
    Run YOLOv8 on a single video frame and render the annotated result.

    The frame is first scaled to a fixed 720-pixel-wide 16:9 size, then passed
    to ``model.predict``; the plotted first result is written into ``st_frame``.

    :param conf (float): Minimum confidence a detection needs to be kept.
    :param model (YOLOv8): Loaded YOLOv8 model used for prediction.
    :param st_frame (Streamlit object): Placeholder that receives the rendered frame.
    :param image (numpy array): A frame as returned by ``cv2.VideoCapture.read()``.
    :return: None
    """
    # Normalise every frame to the same display size (720 x 405).
    target_width = 720
    target_height = int(target_width * (9 / 16))
    resized = cv2.resize(image, (target_width, target_height))

    # predict() yields one result per input image; we pass exactly one.
    first_result = model.predict(resized, conf=conf)[0]
    annotated = first_result.plot()

    st_frame.image(
        annotated,
        caption='Detected Video',
        channels="BGR",
        use_column_width=True,
    )


@st.cache_resource
def load_model(model_path="best.pt"):
    """
    Load a YOLO object detection model from ``model_path``.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded model
    across reruns for the same ``model_path`` instead of reloading weights.

    Parameters:
        model_path (str): The path to the YOLO model weights file.
            Defaults to ``"best.pt"``, the file this helper used to hard-code.

    Returns:
        A YOLO object detection model.
    """
    # The argument was previously ignored ("best.pt" was always loaded),
    # which also made the cache key meaningless.
    model = YOLO(model_path)
    return model


def infer_uploaded_image(conf, model):
    """
    Execute inference for an image uploaded through the sidebar.

    The uploaded image is shown in the left column. After the user presses
    "Execution", the annotated prediction and each box's ``xywh`` values are
    shown in the right column.

    :param conf: Confidence threshold passed to ``model.predict``.
    :param model: An instance of the `YOLOv8` class containing the YOLOv8 model.
    :return: None
    """
    source_img = st.sidebar.file_uploader(
        label="Choose an image...",
        type=("jpg", "jpeg", "png", 'bmp', 'webp')
    )

    col1, col2 = st.columns(2)

    if not source_img:
        # Nothing uploaded yet: keep the empty two-column layout and stop here.
        return

    uploaded_image = Image.open(source_img)
    with col1:
        # adding the uploaded image to the page with caption
        st.image(
            image=source_img,
            caption="Uploaded Image",
            use_column_width=True
        )

    if not st.button("Execution"):
        return

    with st.spinner("Running..."):
        res = model.predict(uploaded_image, conf=conf)
        boxes = res[0].boxes
        # plot() output is BGR; reverse the channel axis so st.image (RGB) shows true colours.
        res_plotted = res[0].plot()[:, :, ::-1]

        with col2:
            st.image(res_plotted,
                     caption="Detected Image",
                     use_column_width=True)
            try:
                with st.expander("Detection Results"):
                    for box in boxes:
                        st.write(box.xywh)
            except Exception as ex:
                # An image is guaranteed to be uploaded here, so report the actual failure.
                st.error(f"Error displaying detection results: {ex}")


def infer_uploaded_video(conf, model):
    """
    Execute inference for a video uploaded through the sidebar.

    The upload is copied to a temporary file because ``cv2.VideoCapture``
    needs a filesystem path. Each frame is then annotated via
    ``_display_detected_frames`` into a single Streamlit placeholder.

    :param conf: Confidence threshold passed to the model.
    :param model: An instance of the `YOLOv8` class containing the YOLOv8 model.
    :return: None
    """
    source_video = st.sidebar.file_uploader(
        label="Choose a video..."
    )

    if not source_video:
        return

    st.video(source_video)

    if not st.button("Execution"):
        return

    with st.spinner("Running..."):
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Keep the original extension so OpenCV/FFmpeg can probe the container.
                suffix = os.path.splitext(getattr(source_video, "name", ""))[1]
                video_path = os.path.join(tmp_dir, "uploaded_video" + suffix)
                # Write and *close* the file before OpenCV opens it. The previous
                # NamedTemporaryFile was never flushed (OpenCV could see an empty or
                # truncated file) and cannot be reopened by name on Windows.
                # getvalue() returns the whole upload regardless of stream position.
                with open(video_path, "wb") as video_file:
                    video_file.write(source_video.getvalue())

                vid_cap = cv2.VideoCapture(video_path)
                try:
                    if not vid_cap.isOpened():
                        st.error("Error loading video: could not open the uploaded file")
                        return
                    st_frame = st.empty()
                    while True:
                        success, image = vid_cap.read()
                        if not success:
                            break
                        _display_detected_frames(conf,
                                                 model,
                                                 st_frame,
                                                 image
                                                 )
                finally:
                    # Release before the temporary directory is deleted.
                    vid_cap.release()
        except Exception as e:
            st.error(f"Error loading video: {e}")


def infer_uploaded_webcam(conf, model):
    """
    Execute inference on the local webcam (device 0) until "Stop running" is pressed.

    Frames are annotated via ``_display_detected_frames`` into one Streamlit
    placeholder. The camera is always released when the function exits.

    :param conf: Confidence threshold passed to the model.
    :param model: An instance of the `YOLOv8` class containing the YOLOv8 model.
    :return: None
    """
    vid_cap = None
    try:
        flag = st.button(
            label="Stop running"
        )
        if flag:
            # Stop was requested on this run: don't open the camera at all.
            return
        vid_cap = cv2.VideoCapture(0)  # local camera
        st_frame = st.empty()
        while True:
            success, image = vid_cap.read()
            if not success:
                break
            _display_detected_frames(
                conf,
                model,
                st_frame,
                image
            )
    except Exception as e:
        st.error(f"Error loading video: {str(e)}")
    finally:
        # Free the device even if the loop is interrupted by an exception;
        # previously the handle leaked on every path except a failed read.
        if vid_cap is not None:
            vid_cap.release()

我需要帮助使用 YOLOv8 实现目标跟踪逻辑。我的模型有 11 个不同的汽车类别。最初，当一辆车距离较远时，它被识别为大众汽车，置信度为 0.3；而当汽车靠近时，它被正确识别为宝马。如果我添加计数逻辑，程序会统计出两辆车（大众、宝马），但实际上只有一辆车。我需要统计每个类别的车辆总数。我是 YOLOv8 的新手，请帮帮我。

object-detection yolo object-tracking yolov7 yolov8
1个回答
0
投票

我有一个使用 YOLOv8 模型跟踪对象的更简单的方法：

from collections import defaultdict
import cv2
import numpy as np
import time
from ultralytics import YOLO
import tkinter as tk
from tkinter import messagebox

# Load the YOLOv8 model (replace 'testing.pt' with your own weights file)
model = YOLO('testing.pt')

# Video source. NOTE: the capture below reads from the default webcam
# (device index 0); video_path is only used if you switch to
# cv2.VideoCapture(video_path) to process a file instead.
video_path = "path/to/video.mp4"
cap = cv2.VideoCapture(0)

# Per-track state, created on first access for a new track ID:
#   'track'               - recent (x, y) box-centre points for drawing the trail
#   'last_detection_time' - time.time() of the last frame the ID was seen in
track_history = defaultdict(lambda: {'track': [], 'last_detection_time': time.time()})

# Function to display alert popup and remove the missing track if acknowledged
def show_alert(track_id):
    """
    Ask the user whether a track that has gone missing should be removed.

    Shows a modal Yes/No dialog. On "yes" the track is deleted from
    ``track_history``. Otherwise its last-detection time is reset, so the
    user is asked again only after another timeout instead of on every frame.

    :param track_id: Tracking ID of the missing object.
    :return: True if the track was removed, False otherwise.
    """
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialog is shown
    try:
        result = messagebox.askquestion(
            "Missing Object",
            f"Object {track_id} is missing! Do you want to remove it?",
            icon='warning',
        )
    finally:
        # A new Tk root is created on every call; destroy it so they don't pile up.
        root.destroy()

    if result == 'yes':
        # Remove the missing track
        track_history.pop(track_id, None)
        print(f"Object {track_id}: Tracking ID - {track_id} removed.")
        return True

    # Kept by the user: restart its timer (without creating a new entry) so
    # the prompt does not reappear on the very next frame.
    track_data = track_history.get(track_id)
    if track_data is not None:
        track_data['last_detection_time'] = time.time()
    return False

# Loop through the video frames. try/finally guarantees the capture and the
# OpenCV windows are released even if tracking or drawing raises.
try:
    while cap.isOpened():
        # Read a frame from the video
        success, frame = cap.read()
        if not success:
            # End of the video reached (or the camera stopped returning frames)
            break

        # Run YOLOv8 tracking on the frame, persisting tracks between frames
        results = model.track(frame, persist=True)
        if not results:
            continue
        result = results[0]

        # Visualize the results on the frame
        annotated_frame = result.plot()

        if result.boxes is not None:
            # Get the boxes (centre x, centre y, width, height)
            boxes = result.boxes.xywh.cpu()

            # boxes.id is None when no track IDs are available for this frame
            track_ids = result.boxes.id.int().cpu().tolist() if result.boxes.id is not None else []

            # Plot the tracks and print tracking IDs
            for box, track_id in zip(boxes, track_ids):
                x, y, w, h = box
                track_data = track_history[track_id]
                track = track_data['track']
                track.append((float(x), float(y)))  # x, y center point
                track_data['last_detection_time'] = time.time()

                if len(track) > 30:  # keep only the 30 most recent points
                    track.pop(0)

                # Draw the tracking lines
                points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
                cv2.polylines(annotated_frame, [points], isClosed=False, color=(230, 230, 230), thickness=10)

                # Print the tracking ID
                print(f"Object {track_id}: Tracking ID - {track_id}")

            # Check for missing tracks; iterate over a copy because show_alert may delete entries
            for track_id, track_data in list(track_history.items()):
                if time.time() - track_data['last_detection_time'] > 3.0:  # if not detected for 3 seconds
                    print(f"Object {track_id}: Tracking ID - {track_id} MISSING")
                    show_alert(track_id)

        # Display the annotated frame (also when there are no boxes, so 'q' always works)
        cv2.imshow("YOLOv8 Tracking", annotated_frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
finally:
    # Release the video capture object and close the display window
    cap.release()
    cv2.destroyAllWindows()

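我在这段代码中使用 Tkinter 添加了警报（弹窗）功能。

所需库：Tkinter

Tkinter 随 Python 标准库一起提供，不需要通过 pip 安装（`pip install tkinker` 是错误的写法）；在部分 Linux 发行版上，需要用系统包管理器安装，例如 `sudo apt install python3-tk`。

当某个被跟踪的对象丢失时，程序会在屏幕上弹出窗口，显示该对象的跟踪 ID。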

© www.soinside.com 2019 - 2024. All rights reserved.