linting errors

This commit is contained in:
omdivyatej
2025-03-25 15:20:09 +05:30
parent f87c235a4c
commit f049f2f5c4
4 changed files with 36 additions and 20 deletions

View File

@@ -9,6 +9,7 @@ WORKING_DIR = "./lightrag_demo"
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
async def initialize_rag(): async def initialize_rag():
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
@@ -21,6 +22,7 @@ async def initialize_rag():
return rag return rag
def main(): def main():
# Initialize RAG instance # Initialize RAG instance
rag = asyncio.run(initialize_rag()) rag = asyncio.run(initialize_rag())
@@ -33,8 +35,7 @@ def main():
print("--- NAIVE mode ---") print("--- NAIVE mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="naive")
param=QueryParam(mode="naive")
) )
) )
@@ -42,8 +43,7 @@ def main():
print("\n--- LOCAL mode ---") print("\n--- LOCAL mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="local")
param=QueryParam(mode="local")
) )
) )
@@ -51,8 +51,7 @@ def main():
print("\n--- GLOBAL mode ---") print("\n--- GLOBAL mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="global")
param=QueryParam(mode="global")
) )
) )
@@ -60,8 +59,7 @@ def main():
print("\n--- HYBRID mode ---") print("\n--- HYBRID mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="hybrid")
param=QueryParam(mode="hybrid")
) )
) )
@@ -69,8 +67,7 @@ def main():
print("\n--- MIX mode ---") print("\n--- MIX mode ---")
print( print(
rag.query( rag.query(
"What are the main themes in this story?", "What are the main themes in this story?", param=QueryParam(mode="mix")
param=QueryParam(mode="mix")
) )
) )
@@ -81,10 +78,11 @@ def main():
"How does the character development reflect Victorian-era attitudes?", "How does the character development reflect Victorian-era attitudes?",
param=QueryParam( param=QueryParam(
mode="global", mode="global",
model_func=gpt_4o_complete # Override default model with more capable one model_func=gpt_4o_complete, # Override default model with more capable one
) ),
) )
) )
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -705,7 +705,11 @@ async def kg_query(
system_prompt: str | None = None, system_prompt: str | None = None,
) -> str | AsyncIterator[str]: ) -> str | AsyncIterator[str]:
# Handle cache # Handle cache
use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"] use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query") args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}") logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
# 5. Call the LLM for keyword extraction # 5. Call the LLM for keyword extraction
use_model_func = param.model_func if param.model_func else global_config["llm_model_func"] use_model_func = (
param.model_func if param.model_func else global_config["llm_model_func"]
)
result = await use_model_func(kw_prompt, keyword_extraction=True) result = await use_model_func(kw_prompt, keyword_extraction=True)
# 6. Parse out JSON from the LLM response # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
3. Combining both results for comprehensive answer generation 3. Combining both results for comprehensive answer generation
""" """
# 1. Cache handling # 1. Cache handling
use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"] use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash("mix", query, cache_type="query") args_hash = compute_args_hash("mix", query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix", cache_type="query" hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
system_prompt: str | None = None, system_prompt: str | None = None,
) -> str | AsyncIterator[str]: ) -> str | AsyncIterator[str]:
# Handle cache # Handle cache
use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"] use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query") args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
# --------------------------- # ---------------------------
# 1) Handle potential cache for query results # 1) Handle potential cache for query results
# --------------------------- # ---------------------------
use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"] use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query") args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache( cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query" hashing_kv, args_hash, query, query_param.mode, cache_type="query"