TypeError:只能连接可迭代(测试调用 SageMaker 端点进行计算机视觉分类)

问题描述 投票:0回答:1

我构建了一个现成的 CNN 并使用 AWS Step Functions 进行部署。我有这些端点的自定义函数:

def input_fn(data, content_type):
    '''
    take in image
    '''
    if content_type == 'application/json':
        img = Image.open(io.BytesIO(data))
        img_arr = np.array(img)
        resized_arr = cv2.resize(img_arr, (img_size, img_size))
        return resized_arr[None,...]
    else:
        raise RuntimeException("{} type is not supported by this endpoint.".format(content_type))



def model_fn():
    '''
    Return model
    '''

    client = boto3.client('s3')
    client.download_file(Bucket=s3_bucket_name, Key='model/kcvg_cv_model.h5', Filename='kcvg_cv_model.h5')
    model = tf.keras.saving.load_model('kcvg_cv_model.h5')

    return model

def predict_fn(img_dir):
    model = model_fn()
    data = input_fn(img_dir)
    prob = model.predict(data)
    return np.argmax(prob, axis=-1)

当我运行这段代码时

from sagemaker.predictor import RealTimePredictor
from sagemaker.serializers import JSONSerializer

endpoint_name = 'odi-ds-belt-vision-cv-kcvg-endpoint-Final-Testing4'

# Read image into memory
payload = None
with open("117.jpg", 'rb') as f:
    payload = f.read()

predictor = RealTimePredictor(endpoint_name = endpoint_name, sagemaker_session=sm_sess, serializer=JSONSerializer)
inference_response = predictor.predict(data=payload)
print (inference_response)

我收到以下错误

The class RealTimePredictor has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[14], line 11
      8     payload = f.read()
     10 predictor = RealTimePredictor(endpoint_name = endpoint_name, sagemaker_session=sm_sess, serializer=JSONSerializer)
---> 11 inference_response = predictor.predict(data=payload)
     12 print (inference_response)

    File ~/anaconda3/envs/odi-ds/lib/python3.9/site-packages/sagemaker/base_predictor.py:177, in Predictor.predict(self, data, initial_args, target_model, target_variant, inference_id, custom_attributes)
        129 def predict(
        130     self,
        131     data,
       (...)
        136     custom_attributes=None,
        137 ):
        138     """Return the inference from the specified endpoint.
        139
        140     Args:
       (...)
        174             as is.
        175     """
    --> 177     request_args = self._create_request_args(
        178         data,
        179         initial_args,
        180         target_model,
        181         target_variant,
        182         inference_id,
        183         custom_attributes,
        184     )
        185     response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
        186     return self._handle_response(response)

    File ~/anaconda3/envs/odi-ds/lib/python3.9/site-packages/sagemaker/base_predictor.py:213, in Predictor._create_request_args(self, data, initial_args, target_model, target_variant, inference_id, custom_attributes)
        207     args["EndpointName"] = self.endpoint_name
        209 if "ContentType" not in args:
        210     args["ContentType"] = (
        211         self.content_type
        212         if isinstance(self.content_type, str)
    --> 213         else ", ".join(self.content_type)
        214     )
        216 if "Accept" not in args:
        217     args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept)

    TypeError: can only join an iterable

我猜我在

input_fn
中做错了什么,但我不太确定。可能是什么?

python typeerror amazon-sagemaker
1个回答
-1
投票

您遇到的问题似乎与

RealTimePredictor
类以及
content_type
的设置方式有关。错误消息
TypeError: can only join an iterable
表明
self.content_type
不是字符串,也不是可以连接的可迭代对象。

您可以考虑以下一些来调试和解决问题:

1.更新到最新的 SageMaker SDK 版本

警告消息表明您使用的是旧版本的 SageMaker SDK,这可能是导致该问题的原因之一。将SDK更新到最新版本,并参考新的API文档查看问题是否依然存在。

2.明确设置
content_type

您可以尝试在创建

content_type
对象时显式设置
RealTimePredictor

predictor = RealTimePredictor(endpoint_name=endpoint_name, 
                              sagemaker_session=sm_sess, 
                              serializer=JSONSerializer, 
                              content_type='application/json')

3.检查 SageMaker 版本

错误消息提到

RealTimePredictor
已在
sagemaker>=2
中重命名。如果您使用的是旧版本,请考虑升级,或者确保根据您的 SageMaker SDK 版本使用适当的类名称。

4.验证自定义功能

仔细检查您的

input_fn
model_fn
predict_fn
函数,确保它们与 SageMaker 端点兼容。在您的
predict_fn
中,您的参数为
img_dir
,但它应该是
(input_data, model)

5.检查
predict_fn
签名

您的

predict_fn
函数似乎与预期的签名不匹配。 SageMaker
predict_fn
预计将具有签名
predict_fn(input_object, model)
。请务必遵循此签名。

这是

predict_fn
的更正版本:

def predict_fn(input_data, model):
    prob = model.predict(input_data)
    return np.argmax(prob, axis=-1)

6.调试

在自定义函数(

input_fn
model_fn
predict_fn
)中添加一些调试打印或日志,以检查它们是否按预期被调用以及它们接收的数据类型。

尝试这些步骤,看看是否可以解决您的问题。

© www.soinside.com 2019 - 2024. All rights reserved.