Fix streaming problem for OpenAI

This commit is contained in:
yangdx
2025-05-09 15:54:54 +08:00
parent 3597239768
commit c2938a71a4

View File

@@ -177,7 +177,7 @@ async def openai_complete_if_cache(
logger.debug("===== Sending Query to LLM =====") logger.debug("===== Sending Query to LLM =====")
try: try:
async with openai_async_client: # Don't use async with context manager, use client directly
if "response_format" in kwargs: if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse( response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs model=model, messages=messages, **kwargs
@@ -188,17 +188,21 @@ async def openai_complete_if_cache(
) )
except APIConnectionError as e: except APIConnectionError as e:
logger.error(f"OpenAI API Connection Error: {e}") logger.error(f"OpenAI API Connection Error: {e}")
await openai_async_client.close() # Ensure client is closed
raise raise
except RateLimitError as e: except RateLimitError as e:
logger.error(f"OpenAI API Rate Limit Error: {e}") logger.error(f"OpenAI API Rate Limit Error: {e}")
await openai_async_client.close() # Ensure client is closed
raise raise
except APITimeoutError as e: except APITimeoutError as e:
logger.error(f"OpenAI API Timeout Error: {e}") logger.error(f"OpenAI API Timeout Error: {e}")
await openai_async_client.close() # Ensure client is closed
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(
f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}" f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}"
) )
await openai_async_client.close() # Ensure client is closed
raise raise
if hasattr(response, "__aiter__"): if hasattr(response, "__aiter__"):
@@ -243,6 +247,8 @@ async def openai_complete_if_cache(
logger.warning( logger.warning(
f"Failed to close stream response: {close_error}" f"Failed to close stream response: {close_error}"
) )
# Ensure client is closed in case of exception
await openai_async_client.close()
raise raise
finally: finally:
# Ensure resources are released even if no exception occurs # Ensure resources are released even if no exception occurs
@@ -258,10 +264,13 @@ async def openai_complete_if_cache(
logger.warning( logger.warning(
f"Failed to close stream response in finally block: {close_error}" f"Failed to close stream response in finally block: {close_error}"
) )
# Note: We don't close the client here for streaming responses
# The client will be closed by the caller after streaming is complete
return inner() return inner()
else: else:
try:
if ( if (
not response not response
or not response.choices or not response.choices
@@ -269,12 +278,14 @@ async def openai_complete_if_cache(
or not hasattr(response.choices[0].message, "content") or not hasattr(response.choices[0].message, "content")
): ):
logger.error("Invalid response from OpenAI API") logger.error("Invalid response from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Invalid response from OpenAI API") raise InvalidResponseError("Invalid response from OpenAI API")
content = response.choices[0].message.content content = response.choices[0].message.content
if not content or content.strip() == "": if not content or content.strip() == "":
logger.error("Received empty content from OpenAI API") logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API") raise InvalidResponseError("Received empty content from OpenAI API")
if r"\u" in content: if r"\u" in content:
@@ -283,7 +294,9 @@ async def openai_complete_if_cache(
if token_tracker and hasattr(response, "usage"): if token_tracker and hasattr(response, "usage"):
token_counts = { token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0), "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"completion_tokens": getattr(response.usage, "completion_tokens", 0), "completion_tokens": getattr(
response.usage, "completion_tokens", 0
),
"total_tokens": getattr(response.usage, "total_tokens", 0), "total_tokens": getattr(response.usage, "total_tokens", 0),
} }
token_tracker.add_usage(token_counts) token_tracker.add_usage(token_counts)
@@ -292,6 +305,9 @@ async def openai_complete_if_cache(
verbose_debug(f"Response: {response}") verbose_debug(f"Response: {response}")
return content return content
finally:
# Ensure client is closed in all cases for non-streaming responses
await openai_async_client.close()
async def openai_complete( async def openai_complete(