为 Donut OCR 模型创建 lambda 函数

问题描述 投票:0回答:1

我在aws中使用sagemaker开发了mlops pipline,管道和代码都设置正确并且每一步都成功,在此之前我在kaggle中处理模型,我找到了一个资源可以帮助我从收据中提取数据, 这就是我正在谈论的代码

image = example['image']
pixel_ = processor(image, return_tensors="pt").pixel_values
import torch

task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]

device = "cuda" if torch.cuda.is_available() else "cpu"

outputs = model.generate(pixel_.to(device),
                               decoder_input_ids=decoder_input_ids.to(device),
                               max_length=model_module.model.decoder.config.max_position_embeddings,
                               early_stopping=True,
                               pad_token_id=processor.tokenizer.pad_token_id,
                               eos_token_id=processor.tokenizer.eos_token_id,
                               use_cache=True,
                               num_beams=1,
                               bad_words_ids=[[processor.tokenizer.unk_token_id]],
                               return_dict_in_generate=True,
                               output_scores=True,)
sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
processor.token2json(sequence)

此代码工作正常,现在返回到 sagemaker mlops,现在已创建端点,但我不知道如何将此代码放入 lambda 函数中,因为我认为这是唯一的方法,事情就是这样代码正在使用处理器和型号。

我尝试制作 lambda 函数,但找不到正确的方法和解决方案,这没有给我任何结果,是否有任何建议如何实现它,或者关于如何使 lambda 函数工作的建议.

amazon-web-services pipeline amazon-sagemaker cicd mlops
1个回答
0
投票
import json
import os
import boto3  # Import for SageMaker runtime connection (optional)
import torch  # Assuming PyTorch model

# Assuming your package name is "receipt_parser" and it's uploaded to S3
s3 = boto3.client('s3')  # Optional for model/processor download (if not pre-loaded)
bucket_name = "your-s3-bucket-name"  # Replace with your bucket name
package_key = "receipt_parser-0.1.0.tar.gz"  # Replace with your package file name

def download_package(s3, bucket_name, package_key):
    """Downloads the pre-packaged model and processor from S3 (optional)."""
    local_path = "/tmp/receipt_parser.tar.gz"
    s3.download_file(bucket_name, package_key, local_path)
    os.system(f"tar -xf {local_path}")  # Extract the package

def lambda_handler(event, context):
    """Processes a receipt image and returns the extracted text."""

    # Download the package if necessary (uncomment if using optional download)
    # download_package(s3, bucket_name, package_key)

    # Load your pre-packaged model and processor (replace with your imports)
    from receipt_parser import processor, model  # Assuming your package structure

    # Get the receipt image data from the event (adjust based on your API design)
    receipt_data = event.get("receipt_image")  # Replace with actual data source

    # Preprocess the receipt image (if necessary)
    # ... (your image pre-processing logic)

    # Convert receipt data to a format suitable for your model (e.g., tensor)
    # pixel_ = processor(receipt_data, return_tensors="pt").pixel_values

    # Load the model and processor to the appropriate device (CPU or GPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Perform inference (replace with your actual model logic)
    outputs = model.generate(pixel_.to(device),
                             # ... (other model inference arguments)
                             )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    extracted_text = processor.token2json(sequence)

    return {
        "statusCode": 200,
        "body": json.dumps({"extracted_text": extracted_text})
    }

该代码包含一个可选的

download_package
函数,用于从 S3 检索模型和处理器(如果它们未预加载到 Lambda 环境中)。您可以将
bucket_name
package_key
等占位符以及模型/处理器导入替换为实际值。根据您的 API 设计(例如,通过事件对象或单独的输入机制)调整接收收据图像数据 (
receipt_data
) 的逻辑。

欲了解更多详细信息,我在其中启发了我的答案,您可以访问this

© www.soinside.com 2019 - 2024. All rights reserved.