diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index f29d10c3..d9939809 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -50,34 +50,38 @@ def create_openai_async_client( client_configs: dict[str, Any] = None, ) -> AsyncOpenAI: """Create an AsyncOpenAI client with the given configuration. - + Args: api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL. client_configs: Additional configuration options for the AsyncOpenAI client. These will override any default configurations but will be overridden by explicit parameters (api_key, base_url). - + Returns: An AsyncOpenAI client instance. """ if not api_key: api_key = os.environ["OPENAI_API_KEY"] - + default_headers = { "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", "Content-Type": "application/json", } - + if client_configs is None: client_configs = {} - + # Create a merged config dict with precedence: explicit params > client_configs > defaults - merged_configs = {**client_configs, "default_headers": default_headers, "api_key": api_key} - + merged_configs = { + **client_configs, + "default_headers": default_headers, + "api_key": api_key, + } + if base_url is not None: merged_configs["base_url"] = base_url - + return AsyncOpenAI(**merged_configs) @@ -99,7 +103,7 @@ async def openai_complete_if_cache( **kwargs: Any, ) -> str: """Complete a prompt using OpenAI's API with caching support. - + Args: model: The OpenAI model to use. prompt: The prompt to complete. @@ -114,10 +118,10 @@ async def openai_complete_if_cache( explicit parameters (api_key, base_url). - hashing_kv: Will be removed from kwargs before passing to OpenAI. - keyword_extraction: Will be removed from kwargs before passing to OpenAI. - + Returns: The completed text or an async iterator of text chunks if streaming. - + Raises: InvalidResponseError: If the response from OpenAI is invalid or empty. APIConnectionError: If there is a connection error with the OpenAI API. @@ -133,18 +137,16 @@ async def openai_complete_if_cache( # Extract client configuration options client_configs = kwargs.pop("openai_client_configs", {}) - + # Create the OpenAI client openai_async_client = create_openai_async_client( - api_key=api_key, - base_url=base_url, - client_configs=client_configs + api_key=api_key, base_url=base_url, client_configs=client_configs ) # Remove special kwargs that shouldn't be passed to OpenAI kwargs.pop("hashing_kv", None) kwargs.pop("keyword_extraction", None) - + # Prepare messages messages: list[dict[str, Any]] = [] if system_prompt: @@ -337,7 +339,7 @@ async def openai_embed( client_configs: dict[str, Any] = None, ) -> np.ndarray: """Generate embeddings for a list of texts using OpenAI's API. - + Args: texts: List of texts to embed. model: The OpenAI embedding model to use. @@ -346,10 +348,10 @@ async def openai_embed( client_configs: Additional configuration options for the AsyncOpenAI client. These will override any default configurations but will be overridden by explicit parameters (api_key, base_url). - + Returns: A numpy array of embeddings, one per input text. - + Raises: APIConnectionError: If there is a connection error with the OpenAI API. RateLimitError: If the OpenAI API rate limit is exceeded. @@ -357,11 +359,9 @@ async def openai_embed( """ # Create the OpenAI client openai_async_client = create_openai_async_client( - api_key=api_key, - base_url=base_url, - client_configs=client_configs + api_key=api_key, base_url=base_url, client_configs=client_configs ) - + response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" )