使用TensorFlow检测模型,我尝试拉取category_name变量以使用指示器在找到特定category_name时发出警报。但是,我在提取变量时遇到了问题。目前使用的是树莓派 4 model b (bullseye OS)
这是我正在运行的主要代码:
import argparse
import sys
import time
import cv2
import numpy as np
import pygame.mixer
from tflite_support.task import core
from tflite_support.task import processor
from tflite_support.task import vision
import utils
def run(model: str, camera_id: int, width: int, height: int, num_threads: int,
enable_edgetpu: bool, alarm_label: str, cat_Name: str) -> None:
# Variables to calculate FPS
counter, fps = 0, 0
start_time = time.time()
# Start capturing video input from the camera
cap = cv2.VideoCapture(camera_id)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
# Visualization parameters
row_size = 20 # pixels
left_margin = 24 # pixels
text_color = (225, 0, 0) # red
font_size = 1
font_thickness = 1
fps_avg_frame_count = 10
# Initialize the object detection model
base_options = core.BaseOptions(
file_name=model, use_coral=enable_edgetpu, num_threads=num_threads)
detection_options = processor.DetectionOptions(
max_results=3, score_threshold=0.3)
options = vision.ObjectDetectorOptions(
base_options=base_options, detection_options=detection_options)
detector = vision.ObjectDetector.create_from_options(options)
# Initialize pygame mixer
pygame.mixer.init()
alarm_sound = pygame.mixer.Sound("alarm1.wav") # Load the alarm sound
# Continuously capture images from the camera and run inference
while cap.isOpened():
success, image = cap.read()
if not success:
sys.exit(
'ERROR: Unable to read from webcam. Please verify your webcam settings.'
)
counter += 1
image = cv2.flip(image, 1)
# Convert the image from BGR to RGB as required by the TFLite model.
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Create a TensorImage object from the RGB image.
input_tensor = vision.TensorImage.create_from_array(rgb_image)
# Run object detection estimation using the model.
detection_result = detector.detect(input_tensor)
#Check if the specific label is detected
alarm_triggered = False
for obj in detection_result.detections:
print("Object 1: ", obj,"\n")
for obj2 in obj.categories:
print("Object 2: ", obj2, "\n")
for obj3 in obj2.category_name:
print("Object 3: ", obj3)
if obj3 == alarm_label:
alarm_triggered = True
# Trigger the alarm sound
alarm_sound.play()
# Trigger Visual Componet
cv2.putText(image, "ALERT: " + alarm_label + " detected!", (50,50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
# Draw keypoints and edges on input image
image = utils.visualize(image, detection_result)
# Calculate the FPS
if counter % fps_avg_frame_count == 0:
end_time = time.time()
fps = fps_avg_frame_count / (end_time - start_time)
start_time = time.time()
# Show the FPS
fps_text = 'FPS = {:.1f}'.format(fps)
text_location = (left_margin, row_size)
cv2.putText(image, fps_text, text_location, cv2.FONT_HERSHEY_PLAIN,
font_size, text_color, font_thickness)
# Stop the program if the ESC key is pressed.
if cv2.waitKey(1) == 27:
break
cv2.imshow('object_detector', image)
# Release the camera and close the window
cap.release()
cv2.destroyAllWindows()
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--model',
help='Path of the object detection model.',
required=False,
default='efficientdet_lite0.tflite')
parser.add_argument(
'--cameraId', help='Id of camera.', required=False, type=int, default=0)
parser.add_argument(
'--frameWidth',
help='Width of frame to capture from camera.',
required=False,
type=int,
default=800)
parser.add_argument(
'--frameHeight',
help='Height of frame to capture from camera.',
required=False,
type=int,
default=600)
parser.add_argument(
'--numThreads',
help='Number of CPU threads to run the model.',
required=False,
type=int,
default=4)
parser.add_argument(
'--enableEdgeTPU',
help='Whether to run the model on EdgeTPU.',
action='store_true',
required=False,
default=False)
parser.add_argument(
'--alarmLabel',
help='Label to trigger alarm.',
required=False,
type=str,
default='person')
parser.add_argument(
'--catName',
help='Object Label Detected.',
required=False,
type=str,
default='cup')
args = parser.parse_args()
run(args.model, args.cameraId, args.frameWidth, args.frameHeight,
args.numThreads, args.enableEdgeTPU, args.alarmLabel, args.catName)
if __name__ == '__main__':
main()
这是对边界框进行分类和应用的检测模型:
import cv2
import numpy as np
from tflite_support.task import processor
_MARGIN = 10 # pixels
_ROW_SIZE = 10 # pixels
_FONT_SIZE = 1
_FONT_THICKNESS = 1
_TEXT_COLOR = (0, 255, 0) # red
def classify(
image: np.ndarray,
detection_result: processor.DetectionResult,
) -> np.ndarray:
for detection in detection_result.detections:
# Draw bounding_box
bbox = detection.bounding_box
start_point = bbox.origin_x, bbox.origin_y
end_point = bbox.origin_x + bbox.width, bbox.origin_y + bbox.height
cv2.rectangle(image, start_point, end_point, _TEXT_COLOR, 3)
# Draw label and score
category = detection.categories[0]
category_name = category.category_name
return category
def visualize(
image: np.ndarray,
detection_result: processor.DetectionResult,
) -> np.ndarray:
for detection in detection_result.detections:
# Draw bounding_box
bbox = detection.bounding_box
start_point = bbox.origin_x, bbox.origin_y
end_point = bbox.origin_x + bbox.width, bbox.origin_y + bbox.height
cv2.rectangle(image, start_point, end_point, _TEXT_COLOR, 3)
# Draw label and score
category = detection.categories[0]
category_name = category.category_name
probability = round(category.score, 2)
result_text = category_name + ' (' + str(probability) + ')'
text_location = (_MARGIN + bbox.origin_x,
_MARGIN + _ROW_SIZE + bbox.origin_y)
cv2.putText(image, result_text, text_location, cv2.FONT_HERSHEY_PLAIN,
_FONT_SIZE, _TEXT_COLOR, _FONT_THICKNESS)
return image
这是我当前的输出
Object 1: Detection(bounding_box=BoundingBox(origin_x=364, origin_y=386, width=179, height=85), categories=[Category(index=14, score=0.3125, display_name='', category_name='bench')])
Object 2: Category(index=14, score=0.3125, display_name='', category_name='bench')
Object 3: b
Object 3: e
Object 3: n
Object 3: c
Object 3: h
我的问题来自 Object3 标签。我希望它显示为“对象 3:长凳”,而不是每次迭代一个字母。
对于我来说,你不需要最后一个
for
循环并直接运行
obj3 = obj2.category_name
喜欢
obj3 = obj2.category_name
print("Object 3: ", obj2.category_name)
if obj3 == alarm_label:
alarm_triggered = True
# Trigger the alarm sound
alarm_sound.play()
# Trigger Visual Componet
cv2.putText(image, "ALERT: " + alarm_label + " detected!", (50,50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)