Refactor embedding functions and add async query limit

- Separate insert/query embedding funcs
- Add query-specific async limit
- Update storage classes to use new funcs
- Protect vector DB save with lock
- Improve config handling for thresholds
Author: yangdx
Date:   2025-01-31 15:00:56 +08:00
Parent: 54b68074a1
Commit: 21481dba8f

2 changed files with 17 additions and 9 deletions

@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
 
     def __post_init__(self):
+        # Initialize lock only for file operations
+        self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         self.cosine_better_than_threshold = config.get(
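
The hunk above reads the cosine-similarity cutoff from the COSINE_THRESHOLD environment variable as the field default, then lets vector_db_storage_cls_kwargs in the global config override it. A minimal sketch of that precedence (the helper name and the exact kwarg key are assumptions; the diff truncates the config.get(...) call):

```python
import os

def resolve_cosine_threshold(global_config: dict) -> float:
    """Hypothetical helper mirroring the override order shown above."""
    # 1. Environment variable (falling back to 0.2) supplies the default.
    default = float(os.getenv("COSINE_THRESHOLD", "0.2"))
    # 2. An explicit entry in vector_db_storage_cls_kwargs wins over that default
    #    (the key name is assumed to match the attribute name).
    kwargs = global_config.get("vector_db_storage_cls_kwargs", {})
    return kwargs.get("cosine_better_than_threshold", default)
```
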
@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error deleting relations for {entity_name}: {e}")
 
     async def index_done_callback(self):
-        self._client.save()
+        # Protect file write operation
+        async with self._save_lock:
+            self._client.save()
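
index_done_callback can be reached from concurrent insert pipelines, so the file write is now serialized through an asyncio.Lock created in __post_init__. A self-contained sketch of the same pattern, using a stand-in client rather than NanoVectorDB's actual API:

```python
import asyncio

class FileBackedStore:
    """Stand-in storage: save() persists in-memory state to a single file."""

    def __init__(self, client):
        self._client = client              # assumed to expose a blocking save()
        self._save_lock = asyncio.Lock()   # one lock per storage instance

    async def index_done_callback(self):
        # Concurrent callers queue on the lock, so only one coroutine writes
        # the file at a time instead of interleaving partial writes.
        async with self._save_lock:
            self._client.save()
```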

@@ -154,6 +154,7 @@ class LightRAG:
     embedding_func: EmbeddingFunc = None  # This must be set (we do want to separate llm from the core, so no more default initialization)
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
+    embedding_func_max_async_query: int = 4
 
     # LLM
     llm_model_func: callable = None  # This must be set (we do want to separate llm from the core, so no more default initialization)
@@ -195,8 +196,11 @@ class LightRAG:
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
-        # Init LLM
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+        # Init embedding functions with separate instances for insert and query
+        self.insert_embedding_func = limit_async_func_call(self.embedding_func_max_async)(
             self.embedding_func
         )
+        self.query_embedding_func = limit_async_func_call(self.embedding_func_max_async_query)(
+            self.embedding_func
+        )
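
limit_async_func_call (from LightRAG's utils) caps how many calls to the wrapped coroutine may be in flight at once; applying it twice to the same embedding_func gives the insert and query paths independent caps. A semaphore-based sketch of the idea, not the library's actual implementation:

```python
import asyncio
from functools import wraps

def limit_concurrency(max_size: int):
    """Hypothetical equivalent of limit_async_func_call: at most max_size concurrent calls."""
    def decorator(func):
        semaphore = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            async with semaphore:
                return await func(*args, **kwargs)

        return wrapper
    return decorator

# Two independently throttled wrappers around the same embedding function,
# mirroring insert_embedding_func (limit 16) and query_embedding_func (limit 4):
# insert_embed = limit_concurrency(16)(embedding_func)
# query_embed = limit_concurrency(4)(embedding_func)
```
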
@@ -238,15 +242,15 @@ class LightRAG:
         ####
         self.full_docs = self.key_string_value_json_storage_cls(
             namespace="full_docs",
-            embedding_func=self.embedding_func,
+            embedding_func=self.insert_embedding_func,
         )
         self.text_chunks = self.key_string_value_json_storage_cls(
             namespace="text_chunks",
-            embedding_func=self.embedding_func,
+            embedding_func=self.insert_embedding_func,
         )
         self.chunk_entity_relation_graph = self.graph_storage_cls(
             namespace="chunk_entity_relation",
-            embedding_func=self.embedding_func,
+            embedding_func=self.insert_embedding_func,
         )
         ####
         # add embedding func by walter over
@@ -254,17 +258,17 @@ class LightRAG:
         self.entities_vdb = self.vector_db_storage_cls(
             namespace="entities",
-            embedding_func=self.embedding_func,
+            embedding_func=self.query_embedding_func,
             meta_fields={"entity_name"},
         )
         self.relationships_vdb = self.vector_db_storage_cls(
             namespace="relationships",
-            embedding_func=self.embedding_func,
+            embedding_func=self.query_embedding_func,
             meta_fields={"src_id", "tgt_id"},
        )
         self.chunks_vdb = self.vector_db_storage_cls(
             namespace="chunks",
-            embedding_func=self.embedding_func,
+            embedding_func=self.query_embedding_func,
         )
 
         if self.llm_response_cache and hasattr(
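
From the caller's side the only visible change is the extra knob. A hypothetical construction follows; the stub callables, the EmbeddingFunc fields, and working_dir are assumptions drawn from LightRAG's public examples, not part of this diff:

```python
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc

async def stub_embed(texts: list[str]) -> np.ndarray:
    # Placeholder embedding; a real setup calls an embedding API here.
    return np.zeros((len(texts), 8))

async def stub_llm(prompt: str, **kwargs) -> str:
    # Placeholder completion function; must be set per the field comment above.
    return "stub answer"

rag = LightRAG(
    working_dir="./rag_storage",          # assumed standard parameter
    llm_model_func=stub_llm,
    embedding_func=EmbeddingFunc(embedding_dim=8, max_token_size=8192, func=stub_embed),
    embedding_func_max_async=16,          # insert-side embedding concurrency (existing default)
    embedding_func_max_async_query=4,     # new: query-side embedding concurrency
)
```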