import React, { useRef, useEffect, useState } from 'react';
import * as tf from '@tensorflow/tfjs';
import * as posenet from '@tensorflow-models/posenet';
const WebcamWithSegmentation = () => {
const webcamRef = useRef(null);
const canvasRef = useRef(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);
const netRef = useRef(null);
const drawBoundingBox = (ctx, box) => {
if (!box) {
return;
}
console.log('Drawing bounding box:', box);
// Ensure non-negative starting point
const startX = Math.max(0, box.x);
const startY = Math.max(0, box.y);
// Calculate adjusted width and height
const adjustedWidth = box.width - Math.max(0, -box.x);
const adjustedHeight = box.height - Math.max(0, -box.y);
ctx.strokeStyle = 'red';
ctx.lineWidth = 4;
ctx.strokeRect(startX, startY, adjustedWidth, adjustedHeight);
};
const getBoundingBox = (keypoints,minScore = 0.5) => {
const validKeypoints = keypoints.filter(point => point.score > minScore);
if (validKeypoints.length === 0) {
return null; // No keypoints with sufficient score
}
const minX = Math.min(...validKeypoints.map(point => point.position.x));
const minY = Math.min(...validKeypoints.map(point => point.position.y));
const maxX = Math.max(...validKeypoints.map(point => point.position.x));
const maxY = Math.max(...validKeypoints.map(point => point.position.y));
return {
x: minX,
y: minY,
width: maxX - minX,
height: maxY - minY,
};
};
const drawPoints = (ctx, keypoints, minScore) => {
keypoints.forEach(keypoint => {
if (keypoint.score >= minScore) {
const { x, y } = keypoint.position;
ctx.beginPath();
ctx.arc(x, y, 5, 0, 2 * Math.PI);
ctx.fillStyle = 'red';
ctx.fill();
}
});
};
useEffect(() => {
const canvas = canvasRef.current;
const context = canvas.getContext('2d');
let net;
const loadPoseNet = async () => {
try {
await tf.setBackend('webgl');
net = await posenet.load();
console.log('PoseNet loaded successfully');
netRef.current = net;
setLoading(false);
} catch (error) {
console.error('Error loading PoseNet:', error);
setError(`Error l
Loading PoseNet: ${error.message}`);
setLoading(false);
}
};
const setupWebcam = async () => {
try {
const stream = await navigator.mediaDevices.getUserMedia({ video: true });
const video = webcamRef.current;
video.srcObject = stream;
return new Promise((resolve) => {
video.onloadedmetadata = () => {
resolve(video);
};
});
} catch (error) {
console.error('Error accessing webcam:', error);
setError(`Error accessing webcam: ${error.message}`);
setLoading(false);
return null;
}
};
const drawWebcam = () => {
const video = webcamRef.current;
const net = netRef.current;
if (!video || !video.videoWidth || !video.videoHeight || !net) {
return;
}
const { videoWidth, videoHeight } = video;
console.log('Video Dimensions:', videoWidth, videoHeight);
if (canvas.width !== videoWidth || canvas.height !== videoHeight) {
canvas.width = videoWidth;
canvas.height = videoHeight;
}
context.clearRect(0, 0, videoWidth, videoHeight);
context.drawImage(video, 0, 0, videoWidth, videoHeight);
net.estimateSinglePose(video, { flipHorizontal: false })
.then(pose => {
/* pose.keypoints.forEach(part => {
console.log(`${part.part} score: ${part.score}`);
});*/
const isFullBodyVisible =
pose.keypoints.find(part => part.part === 'leftAnkle').score >= 0.5 &&
pose.keypoints.find(part => part.part === 'rightAnkle').score >= 0.5 &&
pose.keypoints.find(part => part.part === 'leftShoulder').score >= 0.5 &&
pose.keypoints.find(part => part.part === 'rightShoulder').score >= 0.5;
if (isFullBodyVisible) {
const boundingBox = getBoundingBox(pose.keypoints);
console.log('Bounding Box:', boundingBox);
drawBoundingBox(context, boundingBox);
}
drawPoints(context, pose.keypoints, 0.5);
})
.finally(() => {
requestAnimationFrame(drawWebcam);
});
};
const init = async () => {
const video = await setupWebcam();
if (!video) {
return;
}
await loadPoseNet();
video.play();
requestAnimationFrame(drawWebcam);
};
init();
return () => {
// Cleanup when the component is unmounted
if (webcamRef.current) {
const video = webcamRef.current;
if (video.srcObject) {
const stream = video.srcObject;
const tracks = stream.getTracks();
tracks.forEach(track => track.stop());
}
}
if (netRef.current) {
1. tf.dispose([netRef.current]);
}
};
}, []);
return (
<div>
{loading && <p>Loading PoseNet...</p>}
{error && <p>{error}</p>}
<video ref={webcamRef} autoPlay width={640} height={480} />
<canvas ref={canvasRef} style={{ position: 'absolute', left: 0, top: 0 }} />
</div>
);
};
export default WebcamWithSegmentation;
这是我的代码，但边界框根本没有出现。应该如何修改这段代码？
虽然您的代码看起来基本正确，但计算关键点位置的方式，或用于判断是否显示全身的标准，可能存在问题。请替换 `drawWebcam` 函数，并检查所有必要关键点的分数是否都高于阈值：
/**
 * Suggested replacement for the component's per-frame `drawWebcam` loop.
 *
 * Paste it inside the component's `useEffect`: it reads `canvas`, `context`,
 * `webcamRef`, `netRef`, `getBoundingBox`, `drawBoundingBox` and `drawPoints`
 * from that scope.
 *
 * Why the box never showed up: gating on `keypoints.every(score >= 0.5)`
 * requires all PoseNet keypoints (eyes, ears, wrists, knees, ankles, ...) to
 * be confident at the same time. That is even stricter than checking only
 * shoulders and ankles, and it almost never holds in a webcam frame. Here the
 * box is drawn around whichever keypoints clear the threshold;
 * `getBoundingBox` returns null (and nothing is drawn) when none do.
 *
 * @returns {Promise<void>|undefined} The promise for this frame's pose
 *   processing, or undefined when no estimation was started.
 */
const drawWebcam = () => {
  const minScore = 0.5;
  const video = webcamRef.current;
  const net = netRef.current;
  if (!video || !net) {
    // Unmounted (ref cleared) or model unavailable: let the loop end.
    return undefined;
  }
  if (!video.videoWidth || !video.videoHeight) {
    // Metadata not loaded yet: retry next frame instead of silently ending the loop.
    requestAnimationFrame(drawWebcam);
    return undefined;
  }
  // PoseNet scales keypoints to the <video> element's width/height attributes,
  // not the camera's native resolution, so size the canvas to match; the
  // native size is only a fallback when the attributes are missing.
  const width = video.width || video.videoWidth;
  const height = video.height || video.videoHeight;
  if (canvas.width !== width || canvas.height !== height) {
    canvas.width = width;
    canvas.height = height;
  }
  context.clearRect(0, 0, width, height);
  context.drawImage(video, 0, 0, width, height);
  return net
    .estimateSinglePose(video, { flipHorizontal: false })
    .then(({ keypoints }) => {
      const boundingBox = getBoundingBox(keypoints, minScore);
      if (boundingBox) {
        console.log('Bounding Box:', boundingBox);
        drawBoundingBox(context, boundingBox);
      }
      drawPoints(context, keypoints, minScore);
    })
    .catch((err) => {
      // Keep the loop alive on a failed frame instead of leaving an unhandled rejection.
      console.error('Error estimating pose:', err);
    })
    .finally(() => {
      requestAnimationFrame(drawWebcam);
    });
};