使用flask部署深度学习模型时出错

问题描述 投票:0回答:1

我正在使用图像制作一个关于狗品种预测的项目,我想将其部署在我使用 Flask 的网站上,我从 chatgpt 和多个 YouTube 教程中获得了帮助,因为这是我第一次使用 Flask,但是当我运行我的 app.py 文件时我 cmd 它给了我这个错误

# Load the pre-trained TensorFlow model
model = tf.keras.models.load_model("model path")
# Function to preprocess the image

def preprocess_image(image_path):
    # Open the image using Pillow
    img = Image.open(image_path)
    
    # Resize image to match model's expected input shape
    img = img.resize((224, 224))
    
    # Convert image to numpy array
    img_array = np.array(img)
    
    # Normalize pixel values to range [0, 1]
    img_array = img_array / 255.0
    
    # Expand dimensions to match model's input shape (add batch dimension)
    img_array = np.expand_dims(img_array, axis=0)
    
    return img_array


# Route to handle image upload and prediction
@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return 'No file part'
    
    file = request.files['file']
    
    # Check if the file is empty
    if file.filename == '':
        return 'No selected file'
    
    # Check if the file is an image
    if file and allowed_file(file.filename):
        # Read the image file
        img = Image.open(file)
        # Preprocess the image
        img_array = preprocess_image(img)
        # Perform prediction using the loaded TensorFlow model
        prediction = model.predict(img_array)
        # Process prediction result (you may need to adjust this based on your model)
        # For demonstration, let's assume the model outputs a class index
        predicted_class_index = np.argmax(prediction)
        # Convert class index to a human-readable prediction
        class_labels = ['Class 0', 'Class 1', 'Class 2']  # Replace with your class labels
        prediction_result = class_labels[predicted_class_index]
        return f'The predicted class is: {prediction_result}'
    
    return 'Invalid file format'

# Function to check if the file extension is allowed
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'}

我收到这个错误 enter image description here

请帮助我,我收到此错误,并且我无法理解该错误是在 Flask 代码中还是在我的 DL 模型代码中

flask deep-learning deployment command google-colaboratory
1个回答
0
投票

给定错误的解决方案是这里。您可以通过传递custom_objects来解决问题:

import tensorflow_hub as hub

model = tf.keras.models.load_model(
       "model path",
       custom_objects={'KerasLayer':hub.KerasLayer}
)

您还需要检查

preprocess_image()
输入。看来你向它发送了一个文件(不是文件路径)。

© www.soinside.com 2019 - 2024. All rights reserved.