This commit is contained in:
zrguo
2025-04-03 14:44:56 +08:00
parent 9648300b18
commit e17e61f58e

View File

@@ -50,34 +50,38 @@ def create_openai_async_client(
client_configs: dict[str, Any] = None, client_configs: dict[str, Any] = None,
) -> AsyncOpenAI: ) -> AsyncOpenAI:
"""Create an AsyncOpenAI client with the given configuration. """Create an AsyncOpenAI client with the given configuration.
Args: Args:
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL. base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
client_configs: Additional configuration options for the AsyncOpenAI client. client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). explicit parameters (api_key, base_url).
Returns: Returns:
An AsyncOpenAI client instance. An AsyncOpenAI client instance.
""" """
if not api_key: if not api_key:
api_key = os.environ["OPENAI_API_KEY"] api_key = os.environ["OPENAI_API_KEY"]
default_headers = { default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
if client_configs is None: if client_configs is None:
client_configs = {} client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults # Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {**client_configs, "default_headers": default_headers, "api_key": api_key} merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
if base_url is not None: if base_url is not None:
merged_configs["base_url"] = base_url merged_configs["base_url"] = base_url
return AsyncOpenAI(**merged_configs) return AsyncOpenAI(**merged_configs)
@@ -99,7 +103,7 @@ async def openai_complete_if_cache(
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Complete a prompt using OpenAI's API with caching support. """Complete a prompt using OpenAI's API with caching support.
Args: Args:
model: The OpenAI model to use. model: The OpenAI model to use.
prompt: The prompt to complete. prompt: The prompt to complete.
@@ -114,10 +118,10 @@ async def openai_complete_if_cache(
explicit parameters (api_key, base_url). explicit parameters (api_key, base_url).
- hashing_kv: Will be removed from kwargs before passing to OpenAI. - hashing_kv: Will be removed from kwargs before passing to OpenAI.
- keyword_extraction: Will be removed from kwargs before passing to OpenAI. - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
Returns: Returns:
The completed text or an async iterator of text chunks if streaming. The completed text or an async iterator of text chunks if streaming.
Raises: Raises:
InvalidResponseError: If the response from OpenAI is invalid or empty. InvalidResponseError: If the response from OpenAI is invalid or empty.
APIConnectionError: If there is a connection error with the OpenAI API. APIConnectionError: If there is a connection error with the OpenAI API.
@@ -133,18 +137,16 @@ async def openai_complete_if_cache(
# Extract client configuration options # Extract client configuration options
client_configs = kwargs.pop("openai_client_configs", {}) client_configs = kwargs.pop("openai_client_configs", {})
# Create the OpenAI client # Create the OpenAI client
openai_async_client = create_openai_async_client( openai_async_client = create_openai_async_client(
api_key=api_key, api_key=api_key, base_url=base_url, client_configs=client_configs
base_url=base_url,
client_configs=client_configs
) )
# Remove special kwargs that shouldn't be passed to OpenAI # Remove special kwargs that shouldn't be passed to OpenAI
kwargs.pop("hashing_kv", None) kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None) kwargs.pop("keyword_extraction", None)
# Prepare messages # Prepare messages
messages: list[dict[str, Any]] = [] messages: list[dict[str, Any]] = []
if system_prompt: if system_prompt:
@@ -337,7 +339,7 @@ async def openai_embed(
client_configs: dict[str, Any] = None, client_configs: dict[str, Any] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Generate embeddings for a list of texts using OpenAI's API. """Generate embeddings for a list of texts using OpenAI's API.
Args: Args:
texts: List of texts to embed. texts: List of texts to embed.
model: The OpenAI embedding model to use. model: The OpenAI embedding model to use.
@@ -346,10 +348,10 @@ async def openai_embed(
client_configs: Additional configuration options for the AsyncOpenAI client. client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url). explicit parameters (api_key, base_url).
Returns: Returns:
A numpy array of embeddings, one per input text. A numpy array of embeddings, one per input text.
Raises: Raises:
APIConnectionError: If there is a connection error with the OpenAI API. APIConnectionError: If there is a connection error with the OpenAI API.
RateLimitError: If the OpenAI API rate limit is exceeded. RateLimitError: If the OpenAI API rate limit is exceeded.
@@ -357,11 +359,9 @@ async def openai_embed(
""" """
# Create the OpenAI client # Create the OpenAI client
openai_async_client = create_openai_async_client( openai_async_client = create_openai_async_client(
api_key=api_key, api_key=api_key, base_url=base_url, client_configs=client_configs
base_url=base_url,
client_configs=client_configs
) )
response = await openai_async_client.embeddings.create( response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float" model=model, input=texts, encoding_format="float"
) )