From c2938a71a4df801eccdd153495218ad035372a5b Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 9 May 2025 15:54:54 +0800
Subject: [PATCH] Fix streaming problem for OpenAI

---
 lightrag/llm/openai.py | 82 +++++++++++++++++++++++++-----------------
 1 file changed, 49 insertions(+), 33 deletions(-)

diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index cd44bb93..690ac3f3 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -177,28 +177,32 @@ async def openai_complete_if_cache(
     logger.debug("===== Sending Query to LLM =====")
 
     try:
-        async with openai_async_client:
-            if "response_format" in kwargs:
-                response = await openai_async_client.beta.chat.completions.parse(
-                    model=model, messages=messages, **kwargs
-                )
-            else:
-                response = await openai_async_client.chat.completions.create(
-                    model=model, messages=messages, **kwargs
-                )
+        # Don't use async with context manager, use client directly
+        if "response_format" in kwargs:
+            response = await openai_async_client.beta.chat.completions.parse(
+                model=model, messages=messages, **kwargs
+            )
+        else:
+            response = await openai_async_client.chat.completions.create(
+                model=model, messages=messages, **kwargs
+            )
     except APIConnectionError as e:
         logger.error(f"OpenAI API Connection Error: {e}")
+        await openai_async_client.close()  # Ensure client is closed
         raise
     except RateLimitError as e:
         logger.error(f"OpenAI API Rate Limit Error: {e}")
+        await openai_async_client.close()  # Ensure client is closed
         raise
     except APITimeoutError as e:
         logger.error(f"OpenAI API Timeout Error: {e}")
+        await openai_async_client.close()  # Ensure client is closed
         raise
     except Exception as e:
         logger.error(
             f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}"
         )
+        await openai_async_client.close()  # Ensure client is closed
         raise
 
     if hasattr(response, "__aiter__"):
@@ -243,6 +247,8 @@ async def openai_complete_if_cache(
                     logger.warning(
                         f"Failed to close stream response: {close_error}"
                     )
+                # Ensure client is closed in case of exception
+                await openai_async_client.close()
                 raise
             finally:
                 # Ensure resources are released even if no exception occurs
@@ -258,40 +264,50 @@ async def openai_complete_if_cache(
                         logger.warning(
                             f"Failed to close stream response in finally block: {close_error}"
                         )
+                # Note: We don't close the client here for streaming responses
+                # The client will be closed by the caller after streaming is complete
 
         return inner()
     else:
-        if (
-            not response
-            or not response.choices
-            or not hasattr(response.choices[0], "message")
-            or not hasattr(response.choices[0].message, "content")
-        ):
-            logger.error("Invalid response from OpenAI API")
-            raise InvalidResponseError("Invalid response from OpenAI API")
+        try:
+            if (
+                not response
+                or not response.choices
+                or not hasattr(response.choices[0], "message")
+                or not hasattr(response.choices[0].message, "content")
+            ):
+                logger.error("Invalid response from OpenAI API")
+                await openai_async_client.close()  # Ensure client is closed
+                raise InvalidResponseError("Invalid response from OpenAI API")
 
-        content = response.choices[0].message.content
+            content = response.choices[0].message.content
 
-        if not content or content.strip() == "":
-            logger.error("Received empty content from OpenAI API")
-            raise InvalidResponseError("Received empty content from OpenAI API")
+            if not content or content.strip() == "":
+                logger.error("Received empty content from OpenAI API")
+                await openai_async_client.close()  # Ensure client is closed
+                raise InvalidResponseError("Received empty content from OpenAI API")
 
-        if r"\u" in content:
-            content = safe_unicode_decode(content.encode("utf-8"))
+            if r"\u" in content:
+                content = safe_unicode_decode(content.encode("utf-8"))
 
-        if token_tracker and hasattr(response, "usage"):
-            token_counts = {
-                "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
-                "completion_tokens": getattr(response.usage, "completion_tokens", 0),
-                "total_tokens": getattr(response.usage, "total_tokens", 0),
-            }
-            token_tracker.add_usage(token_counts)
+            if token_tracker and hasattr(response, "usage"):
+                token_counts = {
+                    "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+                    "completion_tokens": getattr(
+                        response.usage, "completion_tokens", 0
+                    ),
+                    "total_tokens": getattr(response.usage, "total_tokens", 0),
+                }
+                token_tracker.add_usage(token_counts)
 
-        logger.debug(f"Response content len: {len(content)}")
-        verbose_debug(f"Response: {response}")
+            logger.debug(f"Response content len: {len(content)}")
+            verbose_debug(f"Response: {response}")
 
-        return content
+            return content
+        finally:
+            # Ensure client is closed in all cases for non-streaming responses
+            await openai_async_client.close()
 
 
 async def openai_complete(
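Note on why the `async with` block had to go: entering `async with openai_async_client:` closes the client as soon as `openai_complete_if_cache` returns, but for streaming calls the function returns an async generator (`inner()`) that is consumed *after* that return, so every chunk read would hit an already-closed client. The sketch below illustrates the failure mode with a stand-in resource; `FakeClient` and `make_stream` are illustrative names, not part of the patch:

```python
import asyncio


class FakeClient:
    """Stands in for AsyncOpenAI: usable until close() is called."""

    def __init__(self):
        self.closed = False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()

    async def close(self):
        self.closed = True

    async def read_chunk(self, i):
        if self.closed:
            raise RuntimeError("client already closed")
        return f"chunk-{i}"


def make_stream(client):
    # Mirrors the patched code path: build the async generator and
    # return it without wrapping the client in `async with`.
    async def inner():
        for i in range(3):
            yield await client.read_chunk(i)

    return inner()


async def main():
    # Pre-patch failure mode: `async with` closes the client when the
    # block exits, which happens before the caller starts iterating.
    client = FakeClient()
    async with client:
        stream = make_stream(client)
    try:
        async for chunk in stream:
            print(chunk)
    except RuntimeError as e:
        print(f"pre-patch behavior: {e}")

    # Patched behavior: keep the client open while the stream is consumed
    # and close it afterwards, as the patch comments describe.
    client = FakeClient()
    stream = make_stream(client)
    try:
        async for chunk in stream:
            print(chunk)
    finally:
        await client.close()


asyncio.run(main())
```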
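And a hypothetical caller-side sketch of how the two paths behave after this patch. It assumes `openai_complete_if_cache` is imported from `lightrag.llm.openai`, that extra keyword arguments such as `stream=True` are forwarded to the OpenAI SDK (as the `**kwargs` pass-through in the diff suggests), and that an API key is configured; the model name and prompt are placeholders:

```python
import asyncio

from lightrag.llm.openai import openai_complete_if_cache


async def main():
    # Placeholder model and prompt; stream=True rides along via **kwargs.
    result = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize LightRAG in one sentence.",
        stream=True,
    )

    if hasattr(result, "__aiter__"):
        # Streaming path: per the patch comments, the client stays open
        # so this generator can keep reading chunks after the call returns.
        async for chunk in result:
            print(chunk, end="", flush=True)
        print()
    else:
        # Non-streaming path: the new try/finally has already closed
        # the client by the time the content string is returned.
        print(result)


asyncio.run(main())
```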