如何对 Sagemaker 端点进行预测? (JSON 错误)

问题描述 投票:0回答:2

我已经部署了 sagemaker 端点,并且现在想在端点上运行预测。端点代表 Sagemaker 管道和模型。我按照教程这里进行操作。我设置预测器并进行预测的代码如下:

from sagemaker.predictor import Predictor
predictor = Predictor(endpoint_name=endpoint_name)
data_df = data_df.drop("LABEL_NAME", axis=1)
pred_count = 1
payload = data_df.iloc[:pred_count].to_string(header=False, index=False).replace("  ", ",")
p = predictor.predict(payload, initial_args={"ContentType": "text/csv"})

这段代码几乎就是他们在我链接的示例中显示的内容,对我来说很有意义。我的管道的 preprocess.py 代码包含以下函数(尽管不确定它们是否相关):

def input_fn(input_data, content_type):
    print("BAHHHHHH")
    if content_type == "text/csv":
        # Read the raw input data as CSV.
        df = pd.read_csv(StringIO(input_data), header=None)
        return df
    else:
        raise ValueError("{} not supported by script!".format(content_type))

def output_fn(prediction, accept):
    print("BAHHHHHH")
    if accept == "application/json":
        instances = []
        for row in prediction.tolist():
            instances.append(row)
        json_output = {"instances": instances}

        return worker.Response(json.dumps(json_output), mimetype=accept)
    elif accept == "text/csv":
        return worker.Response(encoders.encode(prediction, accept), mimetype=accept)
    else:
        raise RuntimeException("{} accept type is not supported by this script.".format(accept))

def predict_fn(input_data, model):
    print("BAHHHHHH")
    features = model.transform(input_data)
    return features

def model_fn(model_dir):
    print("BAHHHHHH")
    """Deserialize fitted model"""
    preprocessor = joblib.load(os.path.join(model_dir, "model.joblib"))
    return preprocessor

运行 Predictor.predict() 方法时,出现以下错误:

botocore.errorfactory.ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{
    "error": "JSON Parse error: Missing a comma or ']' after an array element. at offset: 16"

我在将有效负载变量传递给预测方法之前打印了它,它看起来像这样(我截断了它,因为它很长,但这应该足以看到它是什么样的:

0 999.105105 888.607813 6.0 1 los angeles 2431.666667 1.0 NaN 1177.813623 1.076833e+06 los angeles$1$6 0 60376511012 0.0 0.0 0.0 0.0 0.0 0.0 ............

错误消息还提供了一个可查看更多信息的 URL。它是端点的云监视日志。查看这些日志,我没有看到任何额外的信息,只是一个 400 错误,除了 400 错误之外没有任何其他信息。

因此,我传入的数据格式显然存在一些问题。 input_fn、output_fn、predict_fn 和 model_fn 方法在方法开头都有一个打印语句,但这些语句都没有显示在日志中,因此我认为其中任何一个都没有达到。

我做错了什么?

python amazon-web-services prediction amazon-sagemaker
2个回答
0
投票

我建议进行测试,确保您可以在本地反序列化您的有效负载,以确认有效负载和/或序列化 CSV 并将其加载回内存的代码没有问题。

此外,我建议使用简单的有效负载进行测试并记下行为。


0
投票

这里有两种可能性:

  1. 发送到端点的输入格式不正确。输入在发送到 input_fn 之前会经过以下转换:

    input_data = data[0].get("body") request_property = context.request_processor[0].get_request_properties() content_type = utils.retrieve_content_type_header(request_property) 接受 = request_property.get("接受") 或 request_property.get("接受") 如果不接受或接受== content_types.ANY: 接受 = content_types.JSON 如果 content_type 在 content_types.UTF8_TYPES 中: input_data = input_data.decode("utf-8")

  2. model.transform(input_data) 出现错误的可能性较小。您可以尝试在 jupyter Notebook 中调试它。似乎转换方法需要 JSON,但它没有得到它。

© www.soinside.com 2019 - 2024. All rights reserved.