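I am currently running an XGBoost model in production on AWS sagemaker and serving real-time inference. After a while, I would like to replace it with a newer model trained on more data and keep everything else as is (same endpoint, same inference procedure, so really no change apart from the model itself).
The current deployment procedure is the following: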
from sagemaker.xgboost.model import XGBoostModel
from sagemaker.xgboost.model import XGBoostPredictor

xgboost_model = XGBoostModel(
    model_data=<S3 url>,
    role=<sagemaker role>,
    entry_point='inference.py',
    source_dir='src',
    code_location=<S3 url of other dependencies>,
    framework_version='1.5-1',
    name=model_name)

xgboost_model.deploy(
    instance_type='ml.c5.large',
    initial_instance_count=1,
    endpoint_name=model_name)
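A few weeks later I have retrained the model and would like to redeploy it. I am aware that the .deploy() method creates both an endpoint and an endpoint configuration, so it does it all. I cannot simply rerun my script, because I would get an error.
In previous versions of sagemaker I could have updated the model by passing an extra argument, update_endpoint = True, to the .deploy() method. In sagemaker >= 2.0 this is a no-op. Now, in sagemaker >= 2.0, I need to use the predictor object as stated in the documentation. So I tried the following: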
predictor = XGBoostPredictor(model_name)
predictor.update_endpoint(model_name= model_name)
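This does update the endpoint according to a new endpoint configuration. However, I do not know what it is updating... Nowhere in the two lines above do I specify that the new xgboost_model, trained on more data, is the one to use... so where do I tell the update which model to take?
Thank you!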
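Update
I believe I need to look into production variants, as described in their documentation here. However, their whole tutorial is based on the Amazon SDK for Python (boto3), whose artifacts are hard to manage when each model variant has a different entry point (e.g. a different inference.py script, or any other code I want to package with the model when it goes to production).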
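Since I found an answer to my own question, I am posting it here for anyone who runs into the same problem.
I ended up rewriting my deployment script with the boto3 SDK rather than the sagemaker SDK (or a mix of both, as some of the documentation suggests; some of it may even be outdated on the AWS website).
Below is the full script. It shows how to create a sagemaker model object, an endpoint configuration, and an endpoint to deploy the model for the first time. It also shows how to update the endpoint with a retrained model and new model artifacts such as a new inference script (which was the point of my question).
If you want to bring your own model and update it safely in production on sagemaker through the boto3 API, here is the code that performs all 3 steps: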
import boto3
import sagemaker
import time
from datetime import datetime
from sagemaker import image_uris
from fileManager import *  # this is a local script for helper functions

# name of zipped model and zipped inference code
CODE_TAR = 'your_inference_code_and_other_artifacts.tar.gz'
MODEL_TAR = 'your_saved_xgboost_model.tar.gz'

# sagemaker params
smClient = boto3.client('sagemaker')
smRole = <your_sagemaker_role>
bucket = sagemaker.Session().default_bucket()
# region and INSTANCE_TYPE are used below but were never defined in the script;
# the two values here are assumptions, adjust them to your setup
region = boto3.Session().region_name
INSTANCE_TYPE = 'ml.c5.large'


# deploy algorithm
class Deployer:

    def __init__(self, modelName, deployRetrained=False):
        self.modelName = modelName
        self.deployRetrained = deployRetrained
        self.prefix = <S3_model_path_prefix>

    def deploy(self):
        '''
        Main method to create a sagemaker model, create an endpoint configuration and deploy the model. If deployRetrained
        param is set to True, this method will update an already existing endpoint.
        '''
        # define model name and endpoint name to be used for model deployment/update
        model_name = self.modelName + <any_suffix>
        endpoint_config_name = self.modelName + '-%s' % datetime.now().strftime('%Y-%m-%d-%HH%M')
        endpoint_name = self.modelName

        # deploy model for the first time
        if not self.deployRetrained:
            print('Deploying for the first time')
            # here you should copy and zip the model dependencies that you may have (such as preprocessors, inference code, config code...)
            # mine were zipped into the file called CODE_TAR
            # upload model and model artifacts needed for inference to S3
            uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix=self.prefix)
            # create sagemaker model and endpoint configuration
            self.createSagemakerModel(model_name)
            self.createEndpointConfig(endpoint_config_name, model_name)
            # deploy model and wait while endpoint is being created
            self.createEndpoint(endpoint_name, endpoint_config_name)
            self.waitWhileCreating(endpoint_name)
        # update model
        else:
            print('Updating existing model')
            # upload model and model artifacts needed for inference (here the old ones are replaced)
            # make sure to make a backup in S3 if you would like to keep the older models
            # we replace the old ones and keep the same names to avoid having to recreate a sagemaker model with a different name for the update!
            uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix=self.prefix)
            # create a new endpoint config that takes the new model
            self.createEndpointConfig(endpoint_config_name, model_name)
            # update endpoint
            self.updateEndpoint(endpoint_name, endpoint_config_name)
            # wait while endpoint updates then delete outdated endpoint config once it is InService
            self.waitWhileCreating(endpoint_name)
            # endpoint config names are built from self.modelName (not model_name, which carries a suffix),
            # so that is the string to match on
            self.deleteOutdatedEndpointConfig(self.modelName, endpoint_config_name)

    def createSagemakerModel(self, model_name):
        '''
        Create a new sagemaker Model object with an xgboost container and an entry point for inference using boto3 API
        '''
        # Retrieve the inference image (container)
        docker_container = image_uris.retrieve(region=region, framework='xgboost', version='1.5-1')
        # Relative S3 path to pre-trained model to create S3 model URI
        model_s3_key = f'{self.prefix}/' + MODEL_TAR
        # Combine bucket name, model file name, and relative S3 path to create S3 model URI
        model_url = f's3://{bucket}/{model_s3_key}'
        # S3 path to the necessary inference code
        code_url = f's3://{bucket}/{self.prefix}/{CODE_TAR}'
        # Create a sagemaker Model object with all its artifacts
        smClient.create_model(
            ModelName=model_name,
            ExecutionRoleArn=smRole,
            PrimaryContainer={
                'Image': docker_container,
                'ModelDataUrl': model_url,
                'Environment': {
                    'SAGEMAKER_PROGRAM': 'inference.py',  # inference.py is at the root of my zipped CODE_TAR
                    'SAGEMAKER_SUBMIT_DIRECTORY': code_url,
                }
            }
        )

    def createEndpointConfig(self, endpoint_config_name, model_name):
        '''
        Create an endpoint configuration (only for boto3 sdk procedure) and set production variants parameters.
        Each retraining procedure will induce a new variant name based on the endpoint configuration name.
        '''
        smClient.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=[
                {
                    'VariantName': endpoint_config_name,
                    'ModelName': model_name,
                    'InstanceType': INSTANCE_TYPE,
                    'InitialInstanceCount': 1
                }
            ]
        )

    def createEndpoint(self, endpoint_name, endpoint_config_name):
        '''
        Deploy the model to an endpoint
        '''
        smClient.create_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint_config_name)

    def deleteOutdatedEndpointConfig(self, name_check, current_endpoint_config):
        '''
        Automatically detect and delete endpoint configurations that contain a string 'name_check'. This method can be used
        after a retrain procedure to delete all previous endpoint configurations but keep the current one named 'current_endpoint_config'.
        '''
        # get a list of available endpoint configurations
        # (this call is paginated, so only the first page of results is examined here)
        all_configs = smClient.list_endpoint_configs()['EndpointConfigs']
        # loop over the names of endpoint configs
        names_list = []
        for config_dict in all_configs:
            endpoint_config_name = config_dict['EndpointConfigName']
            # get only endpoint configs that contain name_check in them and save names to a list
            if name_check in endpoint_config_name:
                names_list.append(endpoint_config_name)
        # remove the current endpoint configuration from the list (we do not want to delete this one since it is live)
        if current_endpoint_config in names_list:
            names_list.remove(current_endpoint_config)
        for name in names_list:
            try:
                smClient.delete_endpoint_config(EndpointConfigName=name)
                print('Deleted endpoint configuration for %s' % name)
            except Exception:
                print('INFO : No endpoint configuration was found for %s' % name)

    def updateEndpoint(self, endpoint_name, endpoint_config_name):
        '''
        Update existing endpoint with a new retrained model
        '''
        smClient.update_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint_config_name,
            RetainAllVariantProperties=True)

    def waitWhileCreating(self, endpoint_name):
        '''
        Poll the endpoint status every 60 seconds while the endpoint is being created or updated.
        '''
        # wait while creating or updating endpoint
        status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
        print('Status: %s' % status)
        while status != 'InService' and status != 'Failed':
            time.sleep(60)
            status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
            print('Status: %s' % status)
        # in case of a deployment failure raise an error
        if status == 'Failed':
            raise ValueError('Endpoint failed to deploy')


if __name__ == "__main__":
    deployer = Deployer('MyDeployedModel', deployRetrained=True)
    deployer.deploy()
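The waitWhileCreating loop in the script keeps polling for as long as the endpoint stays in a transitional state. Below is a minimal sketch of the same polling idea with a time budget added. The function name wait_for_endpoint, its parameters, and the injected describe callable are assumptions made for illustration and are not part of the script above. boto3 also provides a built-in waiter, smClient.get_waiter('endpoint_in_service'), which may be simpler if its default polling settings suit you.

```python
import time


def wait_for_endpoint(describe, endpoint_name, poll_seconds=60,
                      timeout_seconds=3600, sleep=time.sleep):
    """Poll describe(EndpointName=...) until the endpoint is InService.

    describe is expected to behave like smClient.describe_endpoint.
    Raises RuntimeError on 'Failed' and TimeoutError once the budget is spent.
    """
    waited = 0
    while True:
        status = describe(EndpointName=endpoint_name)['EndpointStatus']
        print('Status: %s' % status)
        if status == 'InService':
            return status
        if status == 'Failed':
            raise RuntimeError('Endpoint %s failed to deploy' % endpoint_name)
        if waited >= timeout_seconds:
            raise TimeoutError('Endpoint %s not InService after %d seconds'
                               % (endpoint_name, waited))
        sleep(poll_seconds)
        waited += poll_seconds
```

With the script's client this would be called as wait_for_endpoint(smClient.describe_endpoint, endpoint_name). Passing the describe function in, rather than using the global client, also makes the loop easy to try out without touching AWS.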
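Final remarks:
The sagemaker documentation mentions all of this, but does not state that you can give the create_model method an "entry_point" as well as a "source_dir" for inference dependencies (e.g. normalization artifacts, inference scripts, etc.). This can be done as shown in the PrimaryContainer argument.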
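To spell out that mapping, here is a small sketch that builds the PrimaryContainer dictionary on its own. In the script above, SAGEMAKER_PROGRAM plays the role of entry_point and SAGEMAKER_SUBMIT_DIRECTORY plays the role of source_dir. The helper name build_primary_container and the sample values are hypothetical, chosen only for illustration.

```python
def build_primary_container(image_uri, bucket, prefix, model_tar, code_tar,
                            entry_point='inference.py'):
    """Return a PrimaryContainer dict shaped like the one passed to create_model above."""
    return {
        'Image': image_uri,
        # the trained model archive
        'ModelDataUrl': f's3://{bucket}/{prefix}/{model_tar}',
        'Environment': {
            # script inside the code archive that serves predictions ("entry_point")
            'SAGEMAKER_PROGRAM': entry_point,
            # archive holding that script and its dependencies ("source_dir")
            'SAGEMAKER_SUBMIT_DIRECTORY': f's3://{bucket}/{prefix}/{code_tar}',
        },
    }


# hypothetical values, for illustration only
container = build_primary_container(
    image_uri='<xgboost image uri>',
    bucket='my-bucket',
    prefix='models/xgb',
    model_tar='model.tar.gz',
    code_tar='code.tar.gz',
)
print(container['Environment']['SAGEMAKER_SUBMIT_DIRECTORY'])
# prints s3://my-bucket/models/xgb/code.tar.gz
```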
My fileManager.py script only contains basic functions to make tar files and to upload to and download from my S3 paths. To keep the class simple, I have not included them.
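For readers who need a starting point, the sketch below shows what such helpers could look like. It is not the fileManager.py used above: the names make_tarball and upload_files, their signatures, and the choice to pass the S3 client in as an argument are all assumptions. The only boto3 call relied on is the S3 client's upload_file(Filename, Bucket, Key).

```python
import os
import tarfile


def make_tarball(tar_path, files):
    """Pack the given files into a gzip-compressed tar, each placed at the archive root."""
    with tarfile.open(tar_path, 'w:gz') as tar:
        for path in files:
            tar.add(path, arcname=os.path.basename(path))
    return tar_path


def upload_files(s3_client, bucket, list_files, prefix):
    """Upload each local file to s3://<bucket>/<prefix>/<file name>.

    s3_client is expected to be a boto3 S3 client, e.g. boto3.client('s3').
    """
    for path in list_files:
        key = f'{prefix}/{os.path.basename(path)}'
        s3_client.upload_file(path, bucket, key)
```

Placing inference.py at the archive root matches the comment in the script that SAGEMAKER_PROGRAM points at a file in the root of CODE_TAR.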
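The deleteOutdatedEndpointConfig method may look like overkill, with unnecessary loops and checks. I did it that way because I have several endpoint configurations to handle and wanted to delete the ones that are not live and contain the string name_check (I do not know the exact name of a configuration, since it carries a datetime suffix). Feel free to simplify it or remove it altogether.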
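One possible simplification is to separate the selection of names from the deletion calls, so the selection can be checked without AWS. The helper name select_outdated_configs and the sample names below are assumptions for illustration.

```python
def select_outdated_configs(all_config_names, name_check, current_config):
    """Return the names that contain name_check, leaving out the live configuration."""
    return [name for name in all_config_names
            if name_check in name and name != current_config]


# hypothetical configuration names, for illustration only
names = [
    'MyDeployedModel-2023-01-01-10H00',
    'MyDeployedModel-2023-02-01-09H30',
    'OtherModel-2023-01-15-12H00',
]
print(select_outdated_configs(names, 'MyDeployedModel', 'MyDeployedModel-2023-02-01-09H30'))
# prints ['MyDeployedModel-2023-01-01-10H00']
```

Each returned name can then be passed to smClient.delete_endpoint_config(EndpointConfigName=name). If there are many configurations, the names could be collected with smClient.get_paginator('list_endpoint_configs') so that every page of results is considered.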
Hope it helps.
In model_name, specify the name of a SageMaker Model object; that Model object is where you specify the image_uri, model_data, etc.
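At the boto3 level the same idea looks like this: the model to serve is named in the ModelName field of a production variant inside the endpoint configuration, and the endpoint is then pointed at that configuration. The sketch below assumes a SageMaker Model called model_name already exists (for instance one created with create_model). The function name point_endpoint_at_model, the variant name 'AllTraffic', and the default instance settings are assumptions for illustration.

```python
from datetime import datetime


def point_endpoint_at_model(sm_client, endpoint_name, model_name,
                            instance_type='ml.c5.large', instance_count=1, now=None):
    """Create a fresh endpoint config that references model_name, then
    switch the existing endpoint over to it. Returns the new config name."""
    now = now or datetime.now()
    config_name = '%s-%s' % (endpoint_name, now.strftime('%Y-%m-%d-%HH%M'))
    sm_client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[{
            'VariantName': 'AllTraffic',
            'ModelName': model_name,  # this is where the update is told which model to serve
            'InstanceType': instance_type,
            'InitialInstanceCount': instance_count,
        }],
    )
    sm_client.update_endpoint(EndpointName=endpoint_name,
                              EndpointConfigName=config_name)
    return config_name
```

With a real client this would be called as point_endpoint_at_model(boto3.client('sagemaker'), endpoint_name, model_name). The endpoint name stays the same, so callers of the endpoint are not affected by the switch.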
Thanks! This answer was very helpful.