Merge pull request #1167 from omdivyatej/om-pr

Feature: Dynamic LLM Selection via QueryParam for Optimized Performance
Authored by Daniel.y on 2025-03-25 18:13:44 +08:00, committed by GitHub.
4 changed files with 125 additions and 8 deletions


@@ -10,6 +10,7 @@ from typing import (
     Literal,
     TypedDict,
     TypeVar,
+    Callable,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -84,6 +85,12 @@ class QueryParam:
     ids: list[str] | None = None
     """List of ids to filter the results."""
 
+    model_func: Callable[..., object] | None = None
+    """Optional override for the LLM model function to use for this specific query.
+    If provided, this will be used instead of the global model function.
+    This allows using different models for different query modes.
+    """
+
 
 @dataclass
 class StorageNameSpace(ABC):
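
A minimal usage sketch of the new field (fast_model below is a hypothetical stand-in for any callable matching the global llm_model_func call shape: an async function that takes the prompt and returns a string):

from lightrag import QueryParam

# Hypothetical lightweight model wrapper with the llm_model_func call shape.
async def fast_model(prompt: str, system_prompt: str | None = None, **kwargs) -> str:
    # In practice this would call a cheaper/faster LLM endpoint.
    return f"[fast-model answer to: {prompt[:40]}]"

# Override the LLM for this query only; other queries keep the global model.
param = QueryParam(mode="naive", model_func=fast_model)
default_param = QueryParam(mode="hybrid")  # no override: falls back to llm_model_func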


@@ -1330,11 +1330,15 @@ class LightRAG:
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
+                If param.model_func is provided, it will be used instead of the global model.
             prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
 
         Returns:
             str: The result of the query execution.
         """
+        # If a custom model is provided in param, temporarily update global config
+        global_config = asdict(self)
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
@@ -1343,7 +1347,7 @@ class LightRAG:
                 self.relationships_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1353,7 +1357,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1366,7 +1370,7 @@ class LightRAG:
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
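
As a hedged usage sketch, per-mode model routing through aquery might look like the following, inside an async context; rag is assumed to be an already-initialized LightRAG instance, and fast_model / strong_model are hypothetical callables with the llm_model_func signature:

async def demo(rag) -> None:
    # Cheaper model for simple vector-only retrieval.
    summary = await rag.aquery(
        "Summarize the document set.",
        param=QueryParam(mode="naive", model_func=fast_model),
    )
    # Stronger model for knowledge-graph based reasoning.
    analysis = await rag.aquery(
        "How do the key entities relate to each other?",
        param=QueryParam(mode="hybrid", model_func=strong_model),
    )
    print(summary, analysis, sep="\n")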


@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
 
     # 5. Call the LLM for keyword extraction
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)
 
     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
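
Each query path above resolves the model function the same way: the per-query override wins when it is set, otherwise the global llm_model_func is used. A standalone sketch of that precedence (the resolve_model_func helper is illustrative, not part of the codebase; the dict key matches the global-config accesses shown above):

from typing import Any, Callable

def resolve_model_func(
    override: Callable[..., Any] | None,
    global_config: dict[str, Any],
) -> Callable[..., Any]:
    # Mirrors the pattern repeated in kg_query, naive_query, mix_kg_vector_query,
    # kg_query_with_keywords and extract_keywords_only: a truthy override wins.
    return override if override else global_config["llm_model_func"]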