我已经部署了 sagemaker 端点,并且现在想在端点上运行预测。端点代表 Sagemaker 管道和模型。我按照教程这里进行操作。我设置预测器并进行预测的代码如下:
from sagemaker.predictor import Predictor
predictor = Predictor(endpoint_name=endpoint_name)
data_df = data_df.drop("LABEL_NAME", axis=1)
pred_count = 1
payload = data_df.iloc[:pred_count].to_string(header=False, index=False).replace(" ", ",")
p = predictor.predict(payload, initial_args={"ContentType": "text/csv"})
这段代码几乎就是他们在我链接的示例中显示的内容,对我来说很有意义。我的管道的 preprocess.py 代码包含以下函数(尽管不确定它们是否相关):
def input_fn(input_data, content_type):
print("BAHHHHHH")
if content_type == "text/csv":
# Read the raw input data as CSV.
df = pd.read_csv(StringIO(input_data), header=None)
return df
else:
raise ValueError("{} not supported by script!".format(content_type))
def output_fn(prediction, accept):
print("BAHHHHHH")
if accept == "application/json":
instances = []
for row in prediction.tolist():
instances.append(row)
json_output = {"instances": instances}
return worker.Response(json.dumps(json_output), mimetype=accept)
elif accept == "text/csv":
return worker.Response(encoders.encode(prediction, accept), mimetype=accept)
else:
raise RuntimeException("{} accept type is not supported by this script.".format(accept))
def predict_fn(input_data, model):
print("BAHHHHHH")
features = model.transform(input_data)
return features
def model_fn(model_dir):
print("BAHHHHHH")
"""Deserialize fitted model"""
preprocessor = joblib.load(os.path.join(model_dir, "model.joblib"))
return preprocessor
运行 Predictor.predict() 方法时,出现以下错误:
botocore.errorfactory.ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{
"error": "JSON Parse error: Missing a comma or ']' after an array element. at offset: 16"
我在将有效负载变量传递给预测方法之前打印了它,它看起来像这样(我截断了它,因为它很长,但这应该足以看到它是什么样的:
0 999.105105 888.607813 6.0 1 los angeles 2431.666667 1.0 NaN 1177.813623 1.076833e+06 los angeles$1$6 0 60376511012 0.0 0.0 0.0 0.0 0.0 0.0 ............
错误消息还提供了一个可查看更多信息的 URL。它是端点的云监视日志。查看这些日志,我没有看到任何额外的信息,只是一个 400 错误,除了 400 错误之外没有任何其他信息。
因此,我传入的数据格式显然存在一些问题。 input_fn、output_fn、predict_fn 和 model_fn 方法在方法开头都有一个打印语句,但这些语句都没有显示在日志中,因此我认为其中任何一个都没有达到。
我做错了什么?
我建议进行测试,确保您可以在本地反序列化您的有效负载,以确认有效负载和/或序列化 CSV 并将其加载回内存的代码没有问题。
此外,我建议使用简单的有效负载进行测试并记下行为。
这里有两种可能性:
发送到端点的输入格式不正确。输入在发送到 input_fn 之前会经过以下转换:
input_data = data[0].get("body") request_property = context.request_processor[0].get_request_properties() content_type = utils.retrieve_content_type_header(request_property) 接受 = request_property.get("接受") 或 request_property.get("接受") 如果不接受或接受== content_types.ANY: 接受 = content_types.JSON 如果 content_type 在 content_types.UTF8_TYPES 中: input_data = input_data.decode("utf-8")
model.transform(input_data) 出现错误的可能性较小。您可以尝试在 jupyter Notebook 中调试它。似乎转换方法需要 JSON,但它没有得到它。