From d45dc14069b9d3555236ac1b33fa897f71516588 Mon Sep 17 00:00:00 2001
From: Shane Walker
Date: Thu, 27 Mar 2025 15:39:39 -0700
Subject: [PATCH] feat(openai): add client configuration support to OpenAI
 integration

Add support for custom client configurations in the OpenAI integration,
allowing for more flexible configuration of the AsyncOpenAI client. This
includes:

- Create a reusable helper function `create_openai_async_client`
- Add proper documentation for client configuration options
- Ensure consistent parameter precedence across the codebase
- Update the embedding function to support client configurations
- Add example script demonstrating custom client configuration usage

The changes maintain backward compatibility while providing a cleaner
and more maintainable approach to configuring OpenAI clients.
---
 lightrag/llm/openai.py | 127 ++++++++++++++++++++++++++++++++---------
 1 file changed, 101 insertions(+), 26 deletions(-)

diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 70aa0ceb..394c4370 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -44,6 +44,43 @@ class InvalidResponseError(Exception):
     pass
 
 
+def create_openai_async_client(
+    api_key: str | None = None,
+    base_url: str | None = None,
+    client_configs: dict[str, Any] = None,
+) -> AsyncOpenAI:
+    """Create an AsyncOpenAI client with the given configuration.
+
+    Args:
+        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        An AsyncOpenAI client instance.
+    """
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
+
+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
+
+    if client_configs is None:
+        client_configs = {}
+
+    # Create a merged config dict with precedence: explicit params > client_configs > defaults
+    merged_configs = {**client_configs, "default_headers": default_headers, "api_key": api_key}
+
+    if base_url is not None:
+        merged_configs["base_url"] = base_url
+
+    return AsyncOpenAI(**merged_configs)
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -60,29 +97,54 @@ async def openai_complete_if_cache(
     api_key: str | None = None,
     **kwargs: Any,
 ) -> str:
+    """Complete a prompt using OpenAI's API with caching support.
+
+    Args:
+        model: The OpenAI model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        **kwargs: Additional keyword arguments to pass to the OpenAI API.
+            Special kwargs:
+            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
+                These will be passed to the client constructor but will be overridden by
+                explicit parameters (api_key, base_url).
+            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
+            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+
+    Returns:
+        The completed text or an async iterator of text chunks if streaming.
+
+    Raises:
+        InvalidResponseError: If the response from OpenAI is invalid or empty.
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
     if history_messages is None:
         history_messages = []
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
-
-    default_headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }
 
     # Set openai logger level to INFO when VERBOSE_DEBUG is off
     if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
         logging.getLogger("openai").setLevel(logging.INFO)
 
-    openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-        if base_url is None
-        else AsyncOpenAI(
-            base_url=base_url, default_headers=default_headers, api_key=api_key
-        )
+    # Extract client configuration options
+    client_configs = kwargs.pop("openai_client_configs", {})
+
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key,
+        base_url=base_url,
+        client_configs=client_configs
     )
+
+    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
+
+    # Prepare messages
     messages: list[dict[str, Any]] = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -257,21 +319,34 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str = None,
     api_key: str = None,
+    client_configs: dict[str, Any] = None,
 ) -> np.ndarray:
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
-
-    default_headers = {
-        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }
-    openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-        if base_url is None
-        else AsyncOpenAI(
-            base_url=base_url, default_headers=default_headers, api_key=api_key
-        )
+    """Generate embeddings for a list of texts using OpenAI's API.
+
+    Args:
+        texts: List of texts to embed.
+        model: The OpenAI embedding model to use.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        A numpy array of embeddings, one per input text.
+
+    Raises:
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key,
+        base_url=base_url,
+        client_configs=client_configs
     )
+
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
     )