Merge pull request #1167 from omdivyatej/om-pr

Feature: Dynamic LLM Selection via QueryParam for Optimized Performance
Merged by Daniel.y on 2025-03-25 18:13:44 +08:00, committed by GitHub
4 changed files with 125 additions and 8 deletions
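In short: QueryParam gains an optional model_func field, so any single query can override the instance-wide LLM without reconfiguring the LightRAG object. A minimal sketch of the intended usage, assuming rag is an already initialized LightRAG instance (see the full example file below):

from lightrag import QueryParam
from lightrag.llm.openai import gpt_4o_complete

# Routine queries keep using the model configured on the instance
# (llm_model_func, e.g. gpt_4o_mini_complete in the demo below).
answer = rag.query(
    "What are the main themes in this story?",
    param=QueryParam(mode="hybrid"),
)

# A single harder question can be routed to a stronger model for just
# this call; everything else continues to use the default.
detailed = rag.query(
    "How does the character development reflect Victorian-era attitudes?",
    param=QueryParam(mode="global", model_func=gpt_4o_complete),
)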

View File

@@ -0,0 +1,88 @@
import os
import asyncio

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status

WORKING_DIR = "./lightrag_demo"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        embedding_func=openai_embed,
        llm_model_func=gpt_4o_mini_complete,  # Default model for queries
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main():
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Load the data
    with open("./book.txt", "r", encoding="utf-8") as f:
        rag.insert(f.read())

    # Query with naive mode (default model)
    print("--- NAIVE mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="naive")
        )
    )

    # Query with local mode (default model)
    print("\n--- LOCAL mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="local")
        )
    )

    # Query with global mode (default model)
    print("\n--- GLOBAL mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="global")
        )
    )

    # Query with hybrid mode (default model)
    print("\n--- HYBRID mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
        )
    )

    # Query with mix mode (default model)
    print("\n--- MIX mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="mix")
        )
    )

    # Query with a custom model (gpt-4o) for a more complex question
    print("\n--- Using custom model for complex analysis ---")
    print(
        rag.query(
            "How does the character development reflect Victorian-era attitudes?",
            param=QueryParam(
                mode="global",
                model_func=gpt_4o_complete,  # Override default model with more capable one
            ),
        )
    )


if __name__ == "__main__":
    main()

View File

@@ -10,6 +10,7 @@ from typing import (
    Literal,
    TypedDict,
    TypeVar,
    Callable,
)
import numpy as np
from .utils import EmbeddingFunc

@@ -84,6 +85,12 @@ class QueryParam:
    ids: list[str] | None = None
    """List of ids to filter the results."""

    model_func: Callable[..., object] | None = None
    """Optional override for the LLM model function to use for this specific query.
    If provided, this will be used instead of the global model function.
    This allows using different models for different query modes.
    """


@dataclass
class StorageNameSpace(ABC):
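The docstring's note that this "allows using different models for different query modes" suggests a simple routing pattern on top of the new field. A sketch under that reading; the MODE_TO_MODEL table and route_query helper are illustrative, not part of the library:

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete

# Hypothetical routing table: a cheap model for the lighter modes,
# a stronger one for the knowledge-graph-heavy modes.
MODE_TO_MODEL = {
    "naive": gpt_4o_mini_complete,
    "local": gpt_4o_mini_complete,
    "global": gpt_4o_complete,
    "hybrid": gpt_4o_complete,
}


def route_query(rag: LightRAG, question: str, mode: str) -> str:
    # model_func stays None for modes missing from the table, so the query
    # falls back to the instance-wide llm_model_func exactly as before.
    return rag.query(
        question,
        param=QueryParam(mode=mode, model_func=MODE_TO_MODEL.get(mode)),
    )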

View File

@@ -1330,11 +1330,15 @@ class LightRAG:
        Args:
            query (str): The query to be executed.
            param (QueryParam): Configuration parameters for query execution.
                If param.model_func is provided, it will be used instead of the global model.
            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

        Returns:
            str: The result of the query execution.
        """
        # If a custom model is provided in param, temporarily update global config
        global_config = asdict(self)

        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query(
                query.strip(),
@@ -1343,7 +1347,7 @@
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                global_config,
                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                system_prompt=system_prompt,
            )
@@ -1353,7 +1357,7 @@
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                global_config,
                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                system_prompt=system_prompt,
            )
@@ -1366,7 +1370,7 @@
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
                global_config,
                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                system_prompt=system_prompt,
            )
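The hunks above sit in the async query path, so the same override is available when calling the async API directly. A brief sketch, assuming the method being modified here is LightRAG.aquery and that rag is an initialized instance:

import asyncio

from lightrag import QueryParam
from lightrag.llm.openai import gpt_4o_complete


async def ask(rag) -> str:
    # param.model_func is picked up inside kg_query / naive_query /
    # mix_kg_vector_query; the llm_model_func in the global_config snapshot
    # passed above is only used as the fallback when no override is given.
    return await rag.aquery(
        "How does the character development reflect Victorian-era attitudes?",
        param=QueryParam(mode="mix", model_func=gpt_4o_complete),
    )

# Usage (with an existing rag): answer = asyncio.run(ask(rag))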

View File

@@ -705,7 +705,11 @@ async def kg_query(
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
    logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

    # 5. Call the LLM for keyword extraction
    use_model_func = global_config["llm_model_func"]
    use_model_func = (
        param.model_func if param.model_func else global_config["llm_model_func"]
    )
    result = await use_model_func(kw_prompt, keyword_extraction=True)

    # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
    3. Combining both results for comprehensive answer generation
    """
    # 1. Cache handling
    use_model_func = global_config["llm_model_func"]
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash("mix", query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
    system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
    # ---------------------------
    # 1) Handle potential cache for query results
    # ---------------------------
    use_model_func = global_config["llm_model_func"]
    use_model_func = (
        query_param.model_func
        if query_param.model_func
        else global_config["llm_model_func"]
    )
    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
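The same prefer-the-override-else-fall-back conditional is repeated at five call sites above (kg_query, extract_keywords_only, mix_kg_vector_query, naive_query, kg_query_with_keywords). Since model_func is either a callable or None, an equivalent one-liner could be shared between them; a sketch of that possible refactor, with a hypothetical resolve_model_func helper that is not part of the PR:

from typing import Any, Callable

from lightrag import QueryParam


def resolve_model_func(
    query_param: QueryParam, global_config: dict[str, Any]
) -> Callable[..., object]:
    # Prefer the per-query override; otherwise fall back to the global LLM.
    # Behaves the same as the spelled-out conditional because model_func
    # defaults to None, which is falsy.
    return query_param.model_func or global_config["llm_model_func"]

Writing the fallback out longhand at each call site, as the PR does, keeps the behavior obvious where it is used; the helper above only trades that locality for less repetition.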