Django Rest Framework 异步错误:“'async_generator' 对象不可迭代”

问题描述 投票:0回答:1

我正在开发一个带有异步视图的 Django Rest Framework 项目。我有一个用于流式响应的端点,这是我的代码:

from adrf.views import APIView as View

class ChatAPI(View):
    permission_classes = [IsAuthenticated]

    async def process_dataframe(self, file_urls):
        data_frames = []
        for file_url in file_urls:
            if file_url.endswith('.csv'):
                df = pd.read_csv(file_url)
                data_frames.append(df)
            elif file_url.endswith(('.xls', '.xlsx')):
                xls = pd.ExcelFile(file_url)
                sheet_names = xls.sheet_names
                for sheet_name in sheet_names:
                    df = pd.read_excel(file_url, sheet_name=sheet_name)
                    data_frames.append(df)
        return data_frames

    async def get_cached_data(self, cache_key):
        return cache.get(cache_key)

    async def check_file_extension(self, file_ext, csv_agent_type):
        print(f"Checking file extension: {file_ext}")
        return file_ext in csv_agent_type

    async def check_all_file_extensions(self, file_extensions, csv_agent_type):
        tasks = [self.check_file_extension(
            file_ext, csv_agent_type) for file_ext in file_extensions]
        print(tasks, "tasks")

        results = await asyncio.gather(*tasks)

        print(f"Results: {results}")
        return all(results)

    async def post(self, request, thread_id, user_id):
        try:
            a = time.time()
            print("Start===========")
            csv_agent_type = ['csv', 'xls', 'xlsx']
            file_ids = request.GET.get('file_ids', [])
            message = request.data.get('message')
            chat_history = ""
            streamed_text = ""
            image_url = None

            # Use sync_to_async here
            generate_graph = await ais_graph_required(message)

            print(generate_graph, "....")
            print(thread_id, '==========thread_id')

            # Use sync_to_async here
            thread = await sync_to_async(Thread.objects.get)(id=thread_id)
            file_ids = await sync_to_async(lambda: list(
                thread.files.all().values_list('id', flat=True)))()

            file_types = await sync_to_async(lambda: list(
                User_Files_New.objects.filter(id__in=file_ids).values_list('file_name', flat=True)))()
  # Use sync_to_async here
            file_extensions = [file_name.split(
                ".")[-1] for file_name in file_types]

            cache_key = f"kibee_agent_cache_{user_id}_{thread_id}"
            print(cache_key)
            cached_data = await self.get_cached_data(cache_key)
            print("got error here: retriving the cached data")

            if cached_data:
                chat_history = cached_data['chat_history']
            print("chat history loaded")

            result = await self.check_all_file_extensions(file_extensions, csv_agent_type)
            print(result, "result")

            if result:
                # Your code here
                print("in if")
                files = await sync_to_async(User_Files_New.objects.filter)(
                    id__in=file_ids)
                print("files")     # Use sync_to_async here
                indexes = await sync_to_async(lambda: list(
                    UserIndexes.objects.filter(id__in=[file.user_index_id for file in files]).values_list(
                        'file_name', flat=True)
                ))()
  # Use sync_to_async here
                print(indexes, "index")
                file_urls = [
                    f"{os.getenv('HOST_URL')}{index}" for index in indexes]

                data_frames = await self.process_dataframe(file_urls)

                print("df loaded")

                agent = await sync_to_async(create_pandas_dataframe_agent)(
                    ChatOpenAI(temperature=0, verbose=True,
                               model=os.getenv("GPT_MODEL"),
                               streaming=True),
                    data_frames,
                    verbose=True,
                    streaming=True,
                    agent_type=AgentType.OPENAI_FUNCTIONS,
                    handle_parsing_errors=True,
                    max_iterations=50,
                    return_intermediate_steps=True,
                    agent_executor_kwargs={
                        "handle_parsing_errors": True,
                    }
                )
                e = time.time()
                print("basic load", e-a)

                if generate_graph:
                    if not os.path.exists(settings.MEDIA_ROOT+f"/{request.user.email}/plots/"):
                        os.makedirs(settings.MEDIA_ROOT +
                                    f"/{request.user.email}/plots/")
                    os.path.join(settings.MEDIA_ROOT,
                                 f"/{request.user.email}/plots/")
                    plot_dir = settings.MEDIA_ROOT + \
                        f"/{request.user.email}/plots/"
                    prompt = await sync_to_async(get_prompt)(
                        chat_history, message, plot_dir, generate_graph)  # Use sync_to_async here
                else:
                    # Use sync_to_async here
                    prompt = await sync_to_async(get_prompt)(chat_history, message)
                f = time.time()
                # print(prompt)
                # print("prompt load", f-e)

                async def generate_stream():
                    print("...........>", prompt)
                    async for chunk in agent.astream({"input": prompt}):
                        res = ""
                        if "actions" in chunk:
                            for action in chunk["actions"]:
                                res += f"Calling Tool: `{action.tool}` with input `{action.tool_input}`\n"
                        elif "steps" in chunk:
                            for step in chunk["steps"]:
                                res += f"Tool Result: `{step.observation}`\n"
                        elif "output" in chunk:
                            res += f'Final Output: {chunk["output"]}\n'
                        else:
                            raise ValueError()

                        # Send each print statement as a separate chunk in the streaming response
                        yield json.dumps({"res": res})
                response = await generate_stream()
                return StreamingHttpResponse(
                    response, content_type='text/event-stream')

        except Exception as e:
            return JsonResponse({"error": str(e)}, status=500)

我收到以下错误

{"error": "'async_generator' object is not iterable"} 

我有一个用于流响应的 Django Rest Framework (DRF) 异步端点。我正在使用 Django Rest Framework、langchain 和 asyncio。在我的代码片段中,我尝试使用异步生成器生成流

我期望异步生成器是可迭代的,并为流响应提供数据块。目标是迭代生成器并在generate_stream函数中生成JSON格式的块。

当尝试在 agent.astream({"input":prompt}): 中使用 async for chunk 迭代异步生成器时,特别会发生此错误。

我已经查看了所使用的库(Django Rest Framework、langchain)的文档,并在我的代码中尝试了一些变体,但问题仍然存在。

django asynchronous django-rest-framework langchain agent
1个回答
0
投票

使用

async_gen.__anext__()
手动推进异步发电机。请记住,这种方法特定于您正在使用的库,您可能需要查阅该库的文档或源代码,以了解迭代其异步生成器的正确方法。

async def generate_stream():
    print("...........>", prompt)
    async_gen = agent.astream({"input": prompt})
    
    try:
        while True:
            chunk = await async_gen.__anext__()
            
            res = ""
            if "actions" in chunk:
                for action in chunk["actions"]:
                    res += f"Calling Tool: `{action.tool}` with input `{action.tool_input}`\n"
            elif "steps" in chunk:
                for step in chunk["steps"]:
                    res += f"Tool Result: `{step.observation}`\n"
            elif "output" in chunk:
                res += f'Final Output: {chunk["output"]}\n'
            else:
                raise ValueError()

            # Send each print statement as a separate chunk in the streaming response
            yield json.dumps({"res": res})
    except StopAsyncIteration:
        pass

response = await generate_stream()
return StreamingHttpResponse(response, content_type='text/event-stream')

您遇到的错误“'async_generator'对象不可迭代”表明agent.astream({“input”:prompt})没有返回可迭代对象。在异步编程中,并非所有异步生成器都可以使用标准 async for 循环语法直接迭代。有些可能需要使用异步迭代方法进行迭代。

如果这不能解决问题,您可能需要检查

langchain
库的文档或联系其维护人员以获取有关正确迭代
agent.astream({"input": prompt})
返回的异步生成器的指导。

© www.soinside.com 2019 - 2024. All rights reserved.