Merge branch 'main' into clear-doc
@@ -44,6 +44,47 @@ class InvalidResponseError(Exception):
     pass


+def create_openai_async_client(
+    api_key: str | None = None,
+    base_url: str | None = None,
+    client_configs: dict[str, Any] = None,
+) -> AsyncOpenAI:
+    """Create an AsyncOpenAI client with the given configuration.
+
+    Args:
+        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        An AsyncOpenAI client instance.
+    """
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
+
+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
+
+    if client_configs is None:
+        client_configs = {}
+
+    # Create a merged config dict with precedence: explicit params > client_configs > defaults
+    merged_configs = {
+        **client_configs,
+        "default_headers": default_headers,
+        "api_key": api_key,
+    }
+
+    if base_url is not None:
+        merged_configs["base_url"] = base_url
+
+    return AsyncOpenAI(**merged_configs)
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
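Note on the precedence merge above: because `**client_configs` is unpacked before the explicit keys, caller-supplied configs override library defaults but can never clobber api_key or base_url. A minimal usage sketch (the endpoint and the timeout/max_retries values are illustrative, not defaults from this patch):

    # Sketch: precedence is explicit params > client_configs > defaults.
    client = create_openai_async_client(
        api_key="sk-...",                                    # explicit param: always wins
        base_url="https://example.com/v1",                   # illustrative OpenAI-compatible endpoint
        client_configs={"timeout": 30.0, "max_retries": 5},  # forwarded into AsyncOpenAI(**merged_configs)
    )
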
@@ -61,29 +102,52 @@ async def openai_complete_if_cache(
     token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
+    """Complete a prompt using OpenAI's API with caching support.
+
+    Args:
+        model: The OpenAI model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        **kwargs: Additional keyword arguments to pass to the OpenAI API.
+            Special kwargs:
+            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
+              These will be passed to the client constructor but will be overridden by
+              explicit parameters (api_key, base_url).
+            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
+            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+
+    Returns:
+        The completed text or an async iterator of text chunks if streaming.
+
+    Raises:
+        InvalidResponseError: If the response from OpenAI is invalid or empty.
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
     if history_messages is None:
         history_messages = []
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
-
-    default_headers = {
-        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }

     # Set openai logger level to INFO when VERBOSE_DEBUG is off
     if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
         logging.getLogger("openai").setLevel(logging.INFO)

-    openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-        if base_url is None
-        else AsyncOpenAI(
-            base_url=base_url, default_headers=default_headers, api_key=api_key
-        )
-    )
+    # Extract client configuration options
+    client_configs = kwargs.pop("openai_client_configs", {})
+
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key, base_url=base_url, client_configs=client_configs
+    )

     # Remove special kwargs that shouldn't be passed to OpenAI
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)

     # Prepare messages
     messages: list[dict[str, Any]] = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
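Because openai_client_configs is popped from **kwargs before the chat call, the underlying client can be tuned per invocation. A sketch, with an illustrative model name and timeout:

    # Sketch: the special kwarg configures the client, then is stripped from kwargs.
    result = await openai_complete_if_cache(
        "gpt-4o-mini",                             # illustrative model
        "Summarize LightRAG in one sentence.",
        openai_client_configs={"timeout": 60.0},   # consumed by create_openai_async_client
    )
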
@@ -272,21 +336,32 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str = None,
     api_key: str = None,
+    client_configs: dict[str, Any] = None,
 ) -> np.ndarray:
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
-
-    default_headers = {
-        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }
-    openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-        if base_url is None
-        else AsyncOpenAI(
-            base_url=base_url, default_headers=default_headers, api_key=api_key
-        )
-    )
+    """Generate embeddings for a list of texts using OpenAI's API.
+
+    Args:
+        texts: List of texts to embed.
+        model: The OpenAI embedding model to use.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        A numpy array of embeddings, one per input text.
+
+    Raises:
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key, base_url=base_url, client_configs=client_configs
+    )
+
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
     )
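With the same client factory now behind openai_embed, embedding calls accept client_configs too; a sketch with illustrative inputs:

    # Sketch: one embedding row per input text.
    embeddings = await openai_embed(
        ["first document", "second document"],
        client_configs={"max_retries": 2},  # illustrative; merged into the AsyncOpenAI constructor
    )
    assert embeddings.shape[0] == 2
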
@@ -697,8 +697,7 @@ async def kg_query(
     if cached_response is not None:
         return cached_response

-    # Extract keywords using extract_keywords_only function which already supports conversation history
-    hl_keywords, ll_keywords = await extract_keywords_only(
+    hl_keywords, ll_keywords = await get_keywords_from_query(
         query, query_param, global_config, hashing_kv
     )

@@ -794,6 +793,38 @@ async def kg_query(
     return response


+async def get_keywords_from_query(
+    query: str,
+    query_param: QueryParam,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> tuple[list[str], list[str]]:
+    """
+    Retrieves high-level and low-level keywords for RAG operations.
+
+    This function checks if keywords are already provided in query parameters,
+    and if not, extracts them from the query text using LLM.
+
+    Args:
+        query: The user's query text
+        query_param: Query parameters that may contain pre-defined keywords
+        global_config: Global configuration dictionary
+        hashing_kv: Optional key-value storage for caching results
+
+    Returns:
+        A tuple containing (high_level_keywords, low_level_keywords)
+    """
+    # Check if pre-defined keywords are already provided
+    if query_param.hl_keywords or query_param.ll_keywords:
+        return query_param.hl_keywords, query_param.ll_keywords
+
+    # Extract keywords using extract_keywords_only function which already supports conversation history
+    hl_keywords, ll_keywords = await extract_keywords_only(
+        query, query_param, global_config, hashing_kv
+    )
+    return hl_keywords, ll_keywords
+
+
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
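The short-circuit above means callers that pre-seed keywords skip the LLM extraction entirely; a sketch, assuming QueryParam exposes hl_keywords/ll_keywords as constructor fields:

    # Sketch: pre-seeded keywords are returned as-is, with no LLM call.
    param = QueryParam(hl_keywords=["graph databases"], ll_keywords=["neo4j", "cypher"])
    hl, ll = await get_keywords_from_query(query, param, global_config, hashing_kv)
    # hl == ["graph databases"], ll == ["neo4j", "cypher"]
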
@@ -934,8 +965,7 @@ async def mix_kg_vector_query(
     # 2. Execute knowledge graph and vector searches in parallel
     async def get_kg_context():
         try:
-            # Extract keywords using extract_keywords_only function which already supports conversation history
-            hl_keywords, ll_keywords = await extract_keywords_only(
+            hl_keywords, ll_keywords = await get_keywords_from_query(
                 query, query_param, global_config, hashing_kv
             )

@@ -983,7 +1013,6 @@ async def mix_kg_vector_query(
         try:
             # Reduce top_k for vector search in hybrid mode since we have structured information from KG
             mix_topk = min(10, query_param.top_k)
-            # TODO: add ids to the query
             results = await chunks_vdb.query(
                 augmented_query, top_k=mix_topk, ids=query_param.ids
             )
@@ -1581,9 +1610,7 @@ async def _get_edge_data(

     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
-        text_units_section_list.append(
-            [i, t["content"], t.get("file_path", "unknown_source")]
-        )
+        text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
     text_units_context = list_of_list_to_csv(text_units_section_list)
     return entities_context, relations_context, text_units_context
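For reference, the section list that feeds list_of_list_to_csv is a header row plus one row per text unit; the values below are illustrative, not output from this patch:

    # Sketch of the shape only.
    text_units_section_list = [
        ["id", "content", "file_path"],
        [0, "first chunk of text...", "docs/intro.md"],  # illustrative row
        [1, "second chunk of text...", "unknown"],       # fallback when file_path is absent
    ]
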
@@ -2017,16 +2044,13 @@ async def query_with_keywords(
         Query response or async iterator
     """
-    # Extract keywords
-    hl_keywords, ll_keywords = await extract_keywords_only(
-        text=query,
-        param=param,
+    hl_keywords, ll_keywords = await get_keywords_from_query(
+        query=query,
+        query_param=param,
         global_config=global_config,
         hashing_kv=hashing_kv,
     )

     param.hl_keywords = hl_keywords
     param.ll_keywords = ll_keywords

     # Create a new string with the prompt and the keywords
     ll_keywords_str = ", ".join(ll_keywords)
     hl_keywords_str = ", ".join(hl_keywords)
@@ -962,6 +962,13 @@ class TokenTracker:
     def __init__(self):
         self.reset()

+    def __enter__(self):
+        self.reset()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        print(self)
+
     def reset(self):
         self.prompt_tokens = 0
         self.completion_tokens = 0
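__enter__ resets the counters and returns self, and __exit__ prints the tracker, so token accounting can be scoped to a with block. A usage sketch (assuming the tracker's __str__ reports the accumulated counts):

    # Sketch: counters reset on entry; summary printed on exit.
    with TokenTracker() as tracker:
        await openai_complete_if_cache(
            "gpt-4o-mini",             # illustrative model
            "Hello",
            token_tracker=tracker,     # the kwarg seen in the signature above
        )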