Fix linting errors

Author: omdivyatej
Date: 2025-03-25 15:20:09 +05:30
Parent: f87c235a4c
Commit: f049f2f5c4
4 changed files with 36 additions and 20 deletions

lightrag/base.py

@@ -84,7 +84,7 @@ class QueryParam:
     ids: list[str] | None = None
     """List of ids to filter the results."""
     model_func: Callable[..., object] | None = None
     """Optional override for the LLM model function to use for this specific query.
     If provided, this will be used instead of the global model function.
     """
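The new field is a per-query override of the LLM callable. A minimal usage sketch, assuming an already-configured LightRAG instance named `rag`; the override function here is hypothetical and only needs to match the usual async model-function shape:

```python
from lightrag import LightRAG, QueryParam

# Hypothetical override: any async callable that takes the prompt and
# tolerates LightRAG's extra kwargs (e.g. keyword_extraction) will do.
async def cheaper_model_func(prompt: str, **kwargs) -> str:
    return "stub completion"  # call a smaller/cheaper LLM here

# model_func applies to this query only; other queries keep using the
# llm_model_func configured globally on the LightRAG instance.
param = QueryParam(mode="hybrid", model_func=cheaper_model_func)
answer = rag.query("What changed in this commit?", param=param)  # rag: existing instance
```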

lightrag/lightrag.py

@@ -1338,7 +1338,7 @@ class LightRAG:
         """
         # If a custom model is provided in param, temporarily update global config
         global_config = asdict(self)

         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
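All of the operate.py hunks below implement the same resolution order against this `global_config` dict. A toy sketch of that order, using a stand-in dataclass instead of the real LightRAG class (names here are illustrative):

```python
from dataclasses import asdict, dataclass
from typing import Any, Callable

@dataclass
class StandInRAG:  # illustration only, not the real LightRAG
    llm_model_func: Callable[..., Any] | None = None

async def default_llm(prompt: str, **kwargs) -> str:
    return "answer from the global model"

# asdict(self) turns the instance into the plain dict the query
# functions receive; the copy leaves function objects intact.
global_config = asdict(StandInRAG(llm_model_func=default_llm))

override = None  # i.e. QueryParam.model_func was left unset
use_model_func = override if override else global_config["llm_model_func"]
assert use_model_func is default_llm
```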

lightrag/operate.py

@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
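Wrapping the ternary in parentheses is the mechanical fix for the long-line lint. Since the override is either None or a callable, and callables are always truthy, an `or` fallback would be equivalent and shorter; an equivalence note only, not what this commit does:

```python
# Stand-ins for query_param.model_func and the global default:
override, default = None, print

# `x if x else y` and `x or y` agree whenever x is None or a callable.
assert (override if override else default) is (override or default)
```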
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
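Note the call site: the override supplied via `param.model_func` is invoked with `keyword_extraction=True`, so it must accept that kwarg and, per step 6, return something JSON-parseable in that mode. A defensive sketch; the key names follow LightRAG's keyword-extraction prompt and should be treated as an assumption:

```python
import json

async def my_model_func(prompt: str, keyword_extraction: bool = False, **kwargs) -> str:
    if keyword_extraction:
        # The caller parses JSON out of this response (step 6 above).
        return json.dumps({"high_level_keywords": [], "low_level_keywords": []})
    return "ordinary completion"  # normal query path
```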
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
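One thing the diff makes visible: each cache key is derived from the mode and query text alone (`compute_args_hash("mix", query, ...)` here), so `model_func` never enters the hash. Assuming the cache is shared across calls, a response produced under one override could be served from cache to a later query that uses a different model function. A toy illustration of why the keys collide:

```python
import hashlib

def toy_args_hash(mode: str, query: str) -> str:
    # Toy stand-in for compute_args_hash: only mode and query contribute.
    return hashlib.md5(f"{mode}:{query}".encode()).hexdigest()

# Identical key whether or not a model_func override is in play.
assert toy_args_hash("mix", "hello") == toy_args_hash("mix", "hello")
```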
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
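The same four-line fallback now appears at five call sites. A follow-up could hoist it into a shared helper; a hypothetical sketch, not part of this commit:

```python
from typing import Any, Callable

def resolve_model_func(
    override: Callable[..., object] | None,
    global_config: dict[str, Any],
) -> Callable[..., object]:
    """Prefer the per-query override, else the globally configured LLM."""
    return override if override else global_config["llm_model_func"]
```

Each call site would then shrink to `use_model_func = resolve_model_func(query_param.model_func, global_config)`, keeping the lint happy without repeating the wrapped ternary.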