specify LLM for query

omdivyatej
2025-03-23 21:33:49 +05:30
parent f8ba98c1ff
commit 3522da1b21
4 changed files with 112 additions and 8 deletions

View File

@@ -0,0 +1,93 @@
import os
import asyncio

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger

setup_logger("lightrag", level="INFO")

WORKING_DIR = "./all_modes_demo"
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def initialize_rag():
    # Initialize LightRAG with a base model (gpt-4o-mini)
    rag = LightRAG(
        working_dir=WORKING_DIR,
        embedding_func=openai_embed,
        llm_model_func=gpt_4o_mini_complete,  # Default model for most queries
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main():
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Load the data
    with open("./book.txt", "r", encoding="utf-8") as f:
        rag.insert(f.read())

    # Example query
    query_text = "What are the main themes in this story?"

    # Demonstrate using default model (gpt-4o-mini) for all modes
    print("\n===== Default Model (gpt-4o-mini) =====")
    for mode in ["local", "global", "hybrid", "naive", "mix"]:
        print(f"\n--- {mode.upper()} mode with default model ---")
        response = rag.query(
            query_text,
            param=QueryParam(mode=mode)
        )
        print(response)

    # Demonstrate using custom model (gpt-4o) for all modes
    print("\n===== Custom Model (gpt-4o) =====")
    for mode in ["local", "global", "hybrid", "naive", "mix"]:
        print(f"\n--- {mode.upper()} mode with custom model ---")
        response = rag.query(
            query_text,
            param=QueryParam(
                mode=mode,
                model_func=gpt_4o_complete  # Override with more capable model
            )
        )
        print(response)

    # Mixed approach - use different models for different modes
    print("\n===== Strategic Model Selection =====")

    # Complex analytical question
    complex_query = "How does the character development in the story reflect Victorian-era social values?"

    # Use default model for simpler modes
    print("\n--- NAIVE mode with default model (suitable for simple retrieval) ---")
    response1 = rag.query(
        complex_query,
        param=QueryParam(mode="naive")  # Use default model for basic retrieval
    )
    print(response1)

    # Use more capable model for complex modes
    print("\n--- HYBRID mode with more capable model (for complex analysis) ---")
    response2 = rag.query(
        complex_query,
        param=QueryParam(
            mode="hybrid",
            model_func=gpt_4o_complete  # Use more capable model for complex analysis
        )
    )
    print(response2)


if __name__ == "__main__":
    main()

View File

@@ -10,6 +10,7 @@ from typing import (
     Literal,
     TypedDict,
     TypeVar,
+    Callable,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -83,6 +84,12 @@ class QueryParam:

     ids: list[str] | None = None
     """List of ids to filter the results."""
+
+    model_func: Callable[..., object] | None = None
+    """Optional override for the LLM model function to use for this specific query.
+    If provided, this will be used instead of the global model function.
+    This allows using different models for different query modes.
+    """


 @dataclass
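
Because model_func is just a callable, any async function with the same calling convention as the bundled completion functions should also work as a per-query override. A minimal sketch, assuming the openai_complete_if_cache helper from lightrag.llm.openai and the keyword arguments visible at the call sites in this commit (system_prompt, history_messages, keyword_extraction); it is illustrative, not part of this commit:

from lightrag import QueryParam
from lightrag.llm.openai import openai_complete_if_cache


# Hypothetical custom override: routes a single query to gpt-4o while the
# instance-wide llm_model_func stays unchanged. The signature mirrors the
# bundled gpt_4o_mini_complete-style functions (an assumption, not from this commit).
async def my_query_model(prompt, system_prompt=None, history_messages=[], **kwargs):
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


# Attach the override to one query only; other queries keep the default model.
param = QueryParam(mode="hybrid", model_func=my_query_model)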

View File

@@ -1330,11 +1330,15 @@ class LightRAG:
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
+                If param.model_func is provided, it will be used instead of the global model.
             prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

         Returns:
             str: The result of the query execution.
         """
+        # If a custom model is provided in param, temporarily update global config
+        global_config = asdict(self)
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
@@ -1343,7 +1347,7 @@ class LightRAG:
                 self.relationships_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1353,7 +1357,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1366,7 +1370,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
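
For reference, a hedged sketch of exercising the async entry point this hunk touches (assuming the hunk belongs to LightRAG.aquery and reusing the setup from the example script above; it presumes documents were already inserted into ./all_modes_demo):

import asyncio

from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed


async def ask():
    rag = LightRAG(
        working_dir="./all_modes_demo",
        embedding_func=openai_embed,
        llm_model_func=gpt_4o_mini_complete,  # instance-wide default
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()

    # The override travels inside QueryParam; downstream, param.model_func
    # takes precedence over global_config["llm_model_func"].
    return await rag.aquery(
        "What motivates the protagonist?",
        param=QueryParam(mode="global", model_func=gpt_4o_complete),
    )


if __name__ == "__main__":
    print(asyncio.run(ask()))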

View File

@@ -705,7 +705,7 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +866,7 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -926,7 +926,7 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1731,7 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1850,7 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
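
The same fallback is repeated across the five query paths (kg_query, extract_keywords_only, mix_kg_vector_query, naive_query, kg_query_with_keywords). A minimal sketch of the shared resolution rule, written as a standalone helper purely for illustration (hypothetical, not part of the commit):

from typing import Any, Callable


def resolve_model_func(query_param: Any, global_config: dict) -> Callable:
    # Per-query override (QueryParam.model_func) takes precedence;
    # otherwise fall back to the instance-wide llm_model_func.
    return query_param.model_func or global_config["llm_model_func"]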