Merge pull request #1167 from omdivyatej/om-pr
Feature: Dynamic LLM Selection via QueryParam for Optimized Performance
examples/lightrag_multi_model_all_modes_demo.py (new file, 88 lines)
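In short, this change lets a single query override the LLM configured on the LightRAG instance: QueryParam gains a model_func field, and the query pipeline falls back to the instance-wide model when it is not set. A minimal sketch of the intended usage, assuming an already-initialized rag instance:

    from lightrag import QueryParam
    from lightrag.llm.openai import gpt_4o_complete

    # This call uses gpt-4o for generation; other queries keep the default
    # model configured on the LightRAG instance.
    answer = rag.query(
        "Summarize the document.",
        param=QueryParam(mode="hybrid", model_func=gpt_4o_complete),
    )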
@@ -0,0 +1,88 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status

WORKING_DIR = "./lightrag_demo"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        embedding_func=openai_embed,
        llm_model_func=gpt_4o_mini_complete,  # Default model for queries
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main():
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Load the data
    with open("./book.txt", "r", encoding="utf-8") as f:
        rag.insert(f.read())

    # Query with naive mode (default model)
    print("--- NAIVE mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="naive")
        )
    )

    # Query with local mode (default model)
    print("\n--- LOCAL mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="local")
        )
    )

    # Query with global mode (default model)
    print("\n--- GLOBAL mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="global")
        )
    )

    # Query with hybrid mode (default model)
    print("\n--- HYBRID mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
        )
    )

    # Query with mix mode (default model)
    print("\n--- MIX mode ---")
    print(
        rag.query(
            "What are the main themes in this story?", param=QueryParam(mode="mix")
        )
    )

    # Query with a custom model (gpt-4o) for a more complex question
    print("\n--- Using custom model for complex analysis ---")
    print(
        rag.query(
            "How does the character development reflect Victorian-era attitudes?",
            param=QueryParam(
                mode="global",
                model_func=gpt_4o_complete,  # Override default model with more capable one
            ),
        )
    )


if __name__ == "__main__":
    main()
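A natural extension of the demo above, not part of this commit, is routing each retrieval mode to a different model through the same model_func override. The sketch below reuses only names already imported in the example; the routing table and helper are hypothetical:

    # Hypothetical per-mode routing: a cheaper model for simple retrieval,
    # a stronger model for the knowledge-graph heavy modes.
    MODE_TO_MODEL = {
        "naive": gpt_4o_mini_complete,
        "local": gpt_4o_mini_complete,
        "global": gpt_4o_complete,
        "hybrid": gpt_4o_complete,
        "mix": gpt_4o_complete,
    }


    def query_with_routing(rag: LightRAG, question: str, mode: str) -> str:
        # Look up the model for this mode and pass it through QueryParam.
        return rag.query(
            question,
            param=QueryParam(mode=mode, model_func=MODE_TO_MODEL[mode]),
        )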
@@ -10,6 +10,7 @@ from typing import (
     Literal,
     TypedDict,
     TypeVar,
+    Callable,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -84,6 +85,12 @@ class QueryParam:
     ids: list[str] | None = None
     """List of ids to filter the results."""
+
+    model_func: Callable[..., object] | None = None
+    """Optional override for the LLM model function to use for this specific query.
+    If provided, this will be used instead of the global model function.
+    This allows using different models for different query modes.
+    """


 @dataclass
 class StorageNameSpace(ABC):
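Because model_func defaults to None, existing QueryParam call sites keep their behavior; only callers that set the field opt into the override. A small illustration, where strong_model is a placeholder callable rather than part of the library:

    from lightrag import QueryParam


    async def strong_model(prompt, **kwargs):
        # Placeholder LLM callable used only for illustration.
        return "answer from the stronger model"


    default_param = QueryParam(mode="hybrid")  # model_func is None, so the global model is used
    override_param = QueryParam(mode="hybrid", model_func=strong_model)  # per-query override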
@@ -1330,11 +1330,15 @@ class LightRAG:
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
+                If param.model_func is provided, it will be used instead of the global model.
             prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

         Returns:
             str: The result of the query execution.
         """
+        # If a custom model is provided in param, temporarily update global config
+        global_config = asdict(self)
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
@@ -1343,7 +1347,7 @@ class LightRAG:
                 self.relationships_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1353,7 +1357,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1366,7 +1370,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
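One way to confirm that the override actually reaches the query path is a stub model function that records its calls. This is a test sketch, not code from this commit; it assumes an initialized rag instance with some data inserted and no cached LLM response for the query:

    calls = []


    async def recording_model(prompt, **kwargs):
        # Stand-in LLM: record the invocation and return a fixed answer.
        calls.append(prompt[:40])
        return "stub answer"


    rag.query(
        "What are the main themes in this story?",
        param=QueryParam(mode="naive", model_func=recording_model),
    )
    assert calls, "expected the per-query model_func to be invoked"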
@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
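The same fallback now appears in kg_query, extract_keywords_only, mix_kg_vector_query, naive_query, and kg_query_with_keywords. If it needs to change again, it could be factored into a small helper; this is a hypothetical refactor rather than part of this commit:

    from typing import Any, Callable

    from lightrag import QueryParam


    def resolve_model_func(query_param: QueryParam, global_config: dict[str, Any]) -> Callable[..., object]:
        # Prefer the per-query override when set, otherwise fall back to the
        # instance-wide model stored in global_config.
        return query_param.model_func or global_config["llm_model_func"]

    # Each call site would then reduce to:
    # use_model_func = resolve_model_func(query_param, global_config)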