我构建了一个现成的 CNN 并使用 AWS Step Functions 进行部署。我有这些端点的自定义函数:
def input_fn(data, content_type):
'''
take in image
'''
if content_type == 'application/json':
img = Image.open(io.BytesIO(data))
img_arr = np.array(img)
resized_arr = cv2.resize(img_arr, (img_size, img_size))
return resized_arr[None,...]
else:
raise RuntimeException("{} type is not supported by this endpoint.".format(content_type))
def model_fn():
'''
Return model
'''
client = boto3.client('s3')
client.download_file(Bucket=s3_bucket_name, Key='model/kcvg_cv_model.h5', Filename='kcvg_cv_model.h5')
model = tf.keras.saving.load_model('kcvg_cv_model.h5')
return model
def predict_fn(img_dir):
model = model_fn()
data = input_fn(img_dir)
prob = model.predict(data)
return np.argmax(prob, axis=-1)
当我运行这段代码时
from sagemaker.predictor import RealTimePredictor
from sagemaker.serializers import JSONSerializer
endpoint_name = 'odi-ds-belt-vision-cv-kcvg-endpoint-Final-Testing4'
# Read image into memory
payload = None
with open("117.jpg", 'rb') as f:
payload = f.read()
predictor = RealTimePredictor(endpoint_name = endpoint_name, sagemaker_session=sm_sess, serializer=JSONSerializer)
inference_response = predictor.predict(data=payload)
print (inference_response)
我收到以下错误
The class RealTimePredictor has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[14], line 11
8 payload = f.read()
10 predictor = RealTimePredictor(endpoint_name = endpoint_name, sagemaker_session=sm_sess, serializer=JSONSerializer)
---> 11 inference_response = predictor.predict(data=payload)
12 print (inference_response)
File ~/anaconda3/envs/odi-ds/lib/python3.9/site-packages/sagemaker/base_predictor.py:177, in Predictor.predict(self, data, initial_args, target_model, target_variant, inference_id, custom_attributes)
129 def predict(
130 self,
131 data,
(...)
136 custom_attributes=None,
137 ):
138 """Return the inference from the specified endpoint.
139
140 Args:
(...)
174 as is.
175 """
--> 177 request_args = self._create_request_args(
178 data,
179 initial_args,
180 target_model,
181 target_variant,
182 inference_id,
183 custom_attributes,
184 )
185 response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
186 return self._handle_response(response)
File ~/anaconda3/envs/odi-ds/lib/python3.9/site-packages/sagemaker/base_predictor.py:213, in Predictor._create_request_args(self, data, initial_args, target_model, target_variant, inference_id, custom_attributes)
207 args["EndpointName"] = self.endpoint_name
209 if "ContentType" not in args:
210 args["ContentType"] = (
211 self.content_type
212 if isinstance(self.content_type, str)
--> 213 else ", ".join(self.content_type)
214 )
216 if "Accept" not in args:
217 args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
TypeError: can only join an iterable
我猜我在
input_fn
中做错了什么,但我不太确定。可能是什么?
您遇到的问题似乎与
RealTimePredictor
类以及 content_type
的设置方式有关。错误消息 TypeError: can only join an iterable
表明 self.content_type
不是字符串,也不是可以连接的可迭代对象。
您可以考虑以下一些来调试和解决问题:
警告消息表明您使用的是旧版本的 SageMaker SDK,这可能是导致该问题的原因之一。将SDK更新到最新版本,并参考新的API文档查看问题是否依然存在。
content_type
您可以尝试在创建
content_type
对象时显式设置 RealTimePredictor
。
predictor = RealTimePredictor(endpoint_name=endpoint_name,
sagemaker_session=sm_sess,
serializer=JSONSerializer,
content_type='application/json')
错误消息提到
RealTimePredictor
已在sagemaker>=2
中重命名。如果您使用的是旧版本,请考虑升级,或者确保根据您的 SageMaker SDK 版本使用适当的类名称。
仔细检查您的
input_fn
、model_fn
和 predict_fn
函数,确保它们与 SageMaker 端点兼容。在您的 predict_fn
中,您的参数为 img_dir
,但它应该是 (input_data, model)
。
predict_fn
签名您的
predict_fn
函数似乎与预期的签名不匹配。 SageMaker predict_fn
预计将具有签名 predict_fn(input_object, model)
。请务必遵循此签名。
这是
predict_fn
的更正版本:
def predict_fn(input_data, model):
prob = model.predict(input_data)
return np.argmax(prob, axis=-1)
在自定义函数(
input_fn
、model_fn
、predict_fn
)中添加一些调试打印或日志,以检查它们是否按预期被调用以及它们接收的数据类型。
尝试这些步骤,看看是否可以解决您的问题。