Fix linting
This commit is contained in:
@@ -197,12 +197,12 @@ class LightRAG:
|
|||||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||||
|
|
||||||
# Init embedding functions with separate instances for insert and query
|
# Init embedding functions with separate instances for insert and query
|
||||||
self.insert_embedding_func = limit_async_func_call(self.embedding_func_max_async)(
|
self.insert_embedding_func = limit_async_func_call(
|
||||||
self.embedding_func
|
self.embedding_func_max_async
|
||||||
)
|
)(self.embedding_func)
|
||||||
self.query_embedding_func = limit_async_func_call(self.embedding_func_max_async_query)(
|
self.query_embedding_func = limit_async_func_call(
|
||||||
self.embedding_func
|
self.embedding_func_max_async_query
|
||||||
)
|
)(self.embedding_func)
|
||||||
|
|
||||||
# Initialize all storages
|
# Initialize all storages
|
||||||
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
|
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
|
||||||
|
@@ -363,7 +363,11 @@ async def extract_entities(
|
|||||||
|
|
||||||
# create a llm function with new_config for handle_cache
|
# create a llm function with new_config for handle_cache
|
||||||
async def custom_llm(
|
async def custom_llm(
|
||||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
prompt,
|
||||||
|
system_prompt=None,
|
||||||
|
history_messages=[],
|
||||||
|
keyword_extraction=False,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
# 合并 new_config 和其他 kwargs,保证其他参数不被覆盖
|
# 合并 new_config 和其他 kwargs,保证其他参数不被覆盖
|
||||||
merged_config = {**kwargs, **new_config}
|
merged_config = {**kwargs, **new_config}
|
||||||
@@ -388,7 +392,7 @@ async def extract_entities(
|
|||||||
_prompt,
|
_prompt,
|
||||||
"default",
|
"default",
|
||||||
cache_type="default",
|
cache_type="default",
|
||||||
llm=custom_llm
|
llm=custom_llm,
|
||||||
)
|
)
|
||||||
if cached_return:
|
if cached_return:
|
||||||
logger.debug(f"Found cache for {arg_hash}")
|
logger.debug(f"Found cache for {arg_hash}")
|
||||||
|
@@ -491,7 +491,9 @@ def dequantize_embedding(
|
|||||||
return (quantized * scale + min_val).astype(np.float32)
|
return (quantized * scale + min_val).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None):
|
async def handle_cache(
|
||||||
|
hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None
|
||||||
|
):
|
||||||
"""Generic cache handling function"""
|
"""Generic cache handling function"""
|
||||||
if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
|
if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
@@ -528,7 +530,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
|
|||||||
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
||||||
mode=mode,
|
mode=mode,
|
||||||
use_llm_check=use_llm_check,
|
use_llm_check=use_llm_check,
|
||||||
llm_func=llm if (use_llm_check and llm is not None) else (llm_model_func if use_llm_check else None),
|
llm_func=llm
|
||||||
|
if (use_llm_check and llm is not None)
|
||||||
|
else (llm_model_func if use_llm_check else None),
|
||||||
original_prompt=prompt if use_llm_check else None,
|
original_prompt=prompt if use_llm_check else None,
|
||||||
cache_type=cache_type,
|
cache_type=cache_type,
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user