linting errors

omdivyatej
2025-03-25 15:20:09 +05:30
parent f87c235a4c
commit f049f2f5c4
4 changed files with 36 additions and 20 deletions

View File

@@ -9,6 +9,7 @@ WORKING_DIR = "./lightrag_demo"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

+
 async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
@@ -21,6 +22,7 @@ async def initialize_rag():
     return rag

+
 def main():
     # Initialize RAG instance
     rag = asyncio.run(initialize_rag())
@@ -33,8 +35,7 @@ def main():
print("--- NAIVE mode ---") print("--- NAIVE mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="naive")
param=QueryParam(mode="naive")
) )
) )
@@ -42,8 +43,7 @@ def main():
print("\n--- LOCAL mode ---") print("\n--- LOCAL mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="local")
param=QueryParam(mode="local")
) )
) )
@@ -51,8 +51,7 @@ def main():
print("\n--- GLOBAL mode ---") print("\n--- GLOBAL mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="global")
param=QueryParam(mode="global")
) )
) )
@@ -60,8 +59,7 @@ def main():
print("\n--- HYBRID mode ---") print("\n--- HYBRID mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="hybrid")
param=QueryParam(mode="hybrid")
) )
) )
@@ -69,8 +67,7 @@ def main():
print("\n--- MIX mode ---") print("\n--- MIX mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="mix")
param=QueryParam(mode="mix")
) )
) )
@@ -81,10 +78,11 @@ def main():
"How does the character development reflect Victorian-era attitudes?", "How does the character development reflect Victorian-era attitudes?",
param=QueryParam( param=QueryParam(
mode="global", mode="global",
model_func=gpt_4o_complete # Override default model with more capable one model_func=gpt_4o_complete, # Override default model with more capable one
) ),
) )
) )
if __name__ == "__main__": if __name__ == "__main__":
main() main()
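The five per-mode calls in the demo above all follow one pattern. A minimal sketch of the same idea, assuming the usual "from lightrag import LightRAG, QueryParam" import used in the LightRAG examples; the helper name query_all_modes is hypothetical:

from lightrag import LightRAG, QueryParam


def query_all_modes(rag: LightRAG, question: str) -> None:
    # Run the same question through every retrieval mode exercised by the demo.
    for mode in ["naive", "local", "global", "hybrid", "mix"]:
        print(f"\n--- {mode.upper()} mode ---")
        print(rag.query(question, param=QueryParam(mode=mode)))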

View File

@@ -84,7 +84,7 @@ class QueryParam:
     ids: list[str] | None = None
     """List of ids to filter the results."""

     model_func: Callable[..., object] | None = None
     """Optional override for the LLM model function to use for this specific query.
     If provided, this will be used instead of the global model function.
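The model_func field added here lets a single query override the instance-wide LLM. A minimal sketch of such an override, assuming gpt_4o_complete can be imported from lightrag.llm.openai (import path assumed, not shown in this diff) and that model functions are awaitable callables as in the operate.py changes below; logged_gpt_4o is a hypothetical wrapper:

from lightrag import QueryParam
from lightrag.llm.openai import gpt_4o_complete  # import path assumed


async def logged_gpt_4o(prompt, **kwargs):
    # Hypothetical wrapper: report the prompt size, then delegate to the stronger model.
    print(f"per-query override, prompt length: {len(prompt)}")
    return await gpt_4o_complete(prompt, **kwargs)


param = QueryParam(mode="global", model_func=logged_gpt_4o)
# rag.query("How does the character development reflect Victorian-era attitudes?", param=param)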

View File

@@ -1338,7 +1338,7 @@ class LightRAG:
""" """
# If a custom model is provided in param, temporarily update global config # If a custom model is provided in param, temporarily update global config
global_config = asdict(self) global_config = asdict(self)
if param.mode in ["local", "global", "hybrid"]: if param.mode in ["local", "global", "hybrid"]:
response = await kg_query( response = await kg_query(
query.strip(), query.strip(),
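Only the knowledge-graph branch of the mode dispatch is visible in this hunk. A rough, self-contained sketch of that routing with stand-in coroutines (the real code routes to query functions like kg_query, mix_kg_vector_query, and naive_query with many more arguments):

import asyncio
from dataclasses import dataclass


@dataclass
class Param:
    mode: str = "global"  # stand-in for QueryParam


async def _kg(q): return f"kg answer: {q}"        # stand-in for kg_query
async def _mix(q): return f"mix answer: {q}"      # stand-in for mix_kg_vector_query
async def _naive(q): return f"naive answer: {q}"  # stand-in for naive_query


async def dispatch(query: str, param: Param) -> str:
    # Route the query by mode, mirroring the branch shown above.
    if param.mode in ["local", "global", "hybrid"]:
        return await _kg(query.strip())
    if param.mode == "mix":
        return await _mix(query.strip())
    if param.mode == "naive":
        return await _naive(query.strip())
    raise ValueError(f"Unknown query mode: {param.mode}")


print(asyncio.run(dispatch("What are the main themes in this story?", Param(mode="naive"))))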

View File

@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
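The wrapped conditional that recurs in the hunks above implements a per-call fallback: use the model_func supplied in the query parameters when it is set, otherwise fall back to llm_model_func from global_config. A self-contained sketch of that pattern; resolve_model_func is an illustrative name, not a library helper:

from typing import Any, Callable, Optional


def resolve_model_func(override: Optional[Callable[..., Any]], global_config: dict) -> Callable[..., Any]:
    # Prefer the per-query override; otherwise use the instance-wide default.
    return override if override else global_config["llm_model_func"]


config = {"llm_model_func": lambda prompt, **kw: f"default model saw: {prompt}"}
print(resolve_model_func(None, config)("hello"))                                  # default path
print(resolve_model_func(lambda p, **kw: f"override saw: {p}", config)("hello"))  # override path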