Merge branch 'main' into add-multi-worker-support

yangdx
2025-03-01 15:55:37 +08:00
31 changed files with 1755 additions and 1371 deletions


@@ -363,14 +363,14 @@ class LightRAG:
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
-            meta_fields={"entity_name"},
+            meta_fields={"entity_name", "source_id", "content"},
         )
         self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
-            meta_fields={"src_id", "tgt_id"},
+            meta_fields={"src_id", "tgt_id", "source_id", "content"},
         )
         self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
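
Adding "source_id" and "content" to meta_fields means the entity and relationship vector stores now persist those two attributes alongside each embedding; the reworked deletion path further down reads them back out of client_storage["data"]. A minimal sketch of how such a whitelist behaves (ToyVectorStorage and its upsert are simplified stand-ins, not the real BaseVectorStorage API):

from dataclasses import dataclass, field


@dataclass
class ToyVectorStorage:
    meta_fields: set[str] = field(default_factory=set)
    data: dict[str, dict] = field(default_factory=dict)

    def upsert(self, records: dict[str, dict]) -> None:
        for record_id, record in records.items():
            # Only whitelisted metadata survives alongside the embedding.
            kept = {k: v for k, v in record.items() if k in self.meta_fields}
            self.data[record_id] = {"__id__": record_id, **kept}


store = ToyVectorStorage(meta_fields={"entity_name", "source_id", "content"})
store.upsert({"ent-1": {"entity_name": "Alice", "source_id": "chunk-abc",
                        "content": "Alice ...", "internal": "dropped"}})
print(store.data["ent-1"])  # "internal" is gone; source_id and content are retained
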
@@ -408,16 +408,31 @@ class LightRAG:
             self._storages_status = StoragesStatus.CREATED

         # Initialize storages
         if self.auto_manage_storages_states:
-            loop = always_get_an_event_loop()
-            loop.run_until_complete(self.initialize_storages())
+            self._run_async_safely(self.initialize_storages, "Storage Initialization")

     def __del__(self):
         # Finalize storages
         if self.auto_manage_storages_states:
-            loop = always_get_an_event_loop()
-            loop.run_until_complete(self.finalize_storages())
+            self._run_async_safely(self.finalize_storages, "Storage Finalization")

+    def _run_async_safely(self, async_func, action_name=""):
+        """Safely execute an async function, avoiding event loop conflicts."""
+        try:
+            loop = always_get_an_event_loop()
+            if loop.is_running():
+                task = loop.create_task(async_func())
+                task.add_done_callback(
+                    lambda t: logger.info(f"{action_name} completed!")
+                )
+            else:
+                loop.run_until_complete(async_func())
+        except RuntimeError:
+            logger.warning(
+                f"No running event loop, creating a new loop for {action_name}."
+            )
+            loop = asyncio.new_event_loop()
+            loop.run_until_complete(async_func())
+            loop.close()
+
     async def initialize_storages(self):
         """Asynchronously initialize the storages"""
@@ -491,7 +506,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
-        ids: list[str] | None = None,
+        ids: str | list[str] | None = None,
     ) -> None:
         """Sync Insert documents with checkpoint support
@@ -500,7 +515,7 @@ class LightRAG:
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
                 split_by_character is None, this parameter is ignored.
-            ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
+            ids: single string of the document ID or list of unique document IDs, if not provided, MD5 hash IDs will be generated
         """
         loop = always_get_an_event_loop()
         loop.run_until_complete(
@@ -512,7 +527,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
-        ids: list[str] | None = None,
+        ids: str | list[str] | None = None,
     ) -> None:
         """Async Insert documents with checkpoint support
@@ -528,12 +543,19 @@ class LightRAG:
             split_by_character, split_by_character_only
         )

-    def insert_custom_chunks(self, full_text: str, text_chunks: list[str]) -> None:
+    def insert_custom_chunks(
+        self,
+        full_text: str,
+        text_chunks: list[str],
+        doc_id: str | list[str] | None = None,
+    ) -> None:
         loop = always_get_an_event_loop()
-        loop.run_until_complete(self.ainsert_custom_chunks(full_text, text_chunks))
+        loop.run_until_complete(
+            self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
+        )

     async def ainsert_custom_chunks(
-        self, full_text: str, text_chunks: list[str]
+        self, full_text: str, text_chunks: list[str], doc_id: str | None = None
     ) -> None:
         update_storage = False
         try:
@@ -542,7 +564,10 @@ class LightRAG:
             text_chunks = [self.clean_text(chunk) for chunk in text_chunks]

             # Process cleaned texts
-            doc_key = compute_mdhash_id(full_text, prefix="doc-")
+            if doc_id is None:
+                doc_key = compute_mdhash_id(full_text, prefix="doc-")
+            else:
+                doc_key = doc_id
             new_docs = {doc_key: {"content": full_text}}

             _add_doc_keys = await self.full_docs.filter_keys({doc_key})
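
insert_custom_chunks gains the same escape hatch: callers that pre-chunk their documents can now pin the document key instead of having it derived from the full text. A usage sketch under the same assumption of an initialized rag instance:

full_text = "Full document text ..."
chunks = ["first chunk ...", "second chunk ..."]

rag.insert_custom_chunks(full_text, chunks, doc_id="doc-custom-001")
# Without doc_id, the key falls back to compute_mdhash_id(full_text, prefix="doc-").
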
@@ -598,6 +623,8 @@ class LightRAG:
         """
         if isinstance(input, str):
             input = [input]
+        if isinstance(ids, str):
+            ids = [ids]

         # 1. Validate ids if provided or generate MD5 hash IDs
         if ids is not None:
@@ -1366,12 +1393,14 @@ class LightRAG:
             logger.debug(f"Starting deletion for document {doc_id}")

+            doc_to_chunk_id = doc_id.replace("doc", "chunk")
+
             # 2. Get all related chunks
-            chunks = await self.text_chunks.get_by_id(doc_id)
+            chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
             if not chunks:
                 return

-            chunk_ids = list(chunks.keys())
+            chunk_ids = {chunks["full_doc_id"].replace("doc", "chunk")}
             logger.debug(f"Found {len(chunk_ids)} chunks to delete")

             # 3. Before deleting, check the related entities and relationships for these chunks
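
The prefix swap relies on LightRAG's key convention: every stored ID is an MD5 digest of some content with a role prefix, produced by compute_mdhash_id in lightrag.utils. A sketch of that convention, and of what the swap does and does not guarantee:

from hashlib import md5


def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Mirrors the LightRAG helper: role prefix + MD5 hex digest of the content.
    return prefix + md5(content.encode()).hexdigest()


doc_key = compute_mdhash_id("full document text", prefix="doc-")
chunk_key = doc_key.replace("doc", "chunk")
print(doc_key)    # doc-<32 hex chars>
print(chunk_key)  # chunk-<same 32 hex chars>
# Caveat: a chunk hashed from its own content yields a different digest, so the
# swap only matches chunks that were keyed from the full document's hash.
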
@@ -1380,7 +1409,7 @@ class LightRAG:
                 entities = [
                     dp
                     for dp in self.entities_vdb.client_storage["data"]
-                    if dp.get("source_id") == chunk_id
+                    if chunk_id in dp.get("source_id")
                 ]
                 logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")
@@ -1388,7 +1417,7 @@ class LightRAG:
                 relations = [
                     dp
                     for dp in self.relationships_vdb.client_storage["data"]
-                    if dp.get("source_id") == chunk_id
+                    if chunk_id in dp.get("source_id")
                ]
                 logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")
@@ -1499,42 +1528,71 @@ class LightRAG:
                     f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
                 )

+            async def process_data(data_type, vdb, chunk_id):
+                # Check data (entities or relationships)
+                data_with_chunk = [
+                    dp
+                    for dp in vdb.client_storage["data"]
+                    if chunk_id
+                    in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                ]
+
+                data_for_vdb = {}
+                if data_with_chunk:
+                    logger.warning(
+                        f"Found {len(data_with_chunk)} {data_type} still referencing chunk {chunk_id}"
+                    )
+
+                    for item in data_with_chunk:
+                        old_sources = item["source_id"].split(GRAPH_FIELD_SEP)
+                        new_sources = [src for src in old_sources if src != chunk_id]
+
+                        if not new_sources:
+                            logger.info(
+                                f"{data_type} {item.get('entity_name', 'N/A')} is deleted because its source_id no longer exists"
+                            )
+                            await vdb.delete_entity(item)
+                        else:
+                            item["source_id"] = GRAPH_FIELD_SEP.join(new_sources)
+                            item_id = item["__id__"]
+                            data_for_vdb[item_id] = item.copy()
+                            if data_type == "entities":
+                                data_for_vdb[item_id]["content"] = data_for_vdb[
+                                    item_id
+                                ].get("content") or (
+                                    item.get("entity_name", "")
+                                    + (item.get("description") or "")
+                                )
+                            else:  # relationships
+                                data_for_vdb[item_id]["content"] = data_for_vdb[
+                                    item_id
+                                ].get("content") or (
+                                    (item.get("keywords") or "")
+                                    + (item.get("src_id") or "")
+                                    + (item.get("tgt_id") or "")
+                                    + (item.get("description") or "")
+                                )
+
+                    if data_for_vdb:
+                        await vdb.upsert(data_for_vdb)
+                        logger.info(f"Successfully updated {data_type} in vector DB")
+
             # Add verification step
             async def verify_deletion():
                 # Verify if the document has been deleted
                 if await self.full_docs.get_by_id(doc_id):
-                    logger.error(f"Document {doc_id} still exists in full_docs")
+                    logger.warning(f"Document {doc_id} still exists in full_docs")

                 # Verify if chunks have been deleted
-                remaining_chunks = await self.text_chunks.get_by_id(doc_id)
+                remaining_chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
                 if remaining_chunks:
-                    logger.error(f"Found {len(remaining_chunks)} remaining chunks")
+                    logger.warning(f"Found {len(remaining_chunks)} remaining chunks")

                 # Verify entities and relationships
                 for chunk_id in chunk_ids:
-                    # Check entities
-                    entities_with_chunk = [
-                        dp
-                        for dp in self.entities_vdb.client_storage["data"]
-                        if chunk_id
-                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
-                    ]
-                    if entities_with_chunk:
-                        logger.error(
-                            f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}"
-                        )
-
-                    # Check relationships
-                    relations_with_chunk = [
-                        dp
-                        for dp in self.relationships_vdb.client_storage["data"]
-                        if chunk_id
-                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
-                    ]
-                    if relations_with_chunk:
-                        logger.error(
-                            f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}"
-                        )
+                    await process_data("entities", self.entities_vdb, chunk_id)
+                    await process_data(
+                        "relationships", self.relationships_vdb, chunk_id
+                    )

             await verify_deletion()
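
process_data prunes the deleted chunk's ID out of each surviving record's source_id, drops records with no sources left, and regenerates a missing content field from the same parts the insertion path concatenates. A standalone sketch of that pruning logic for the entity case (toy records; delete and upsert are stubbed with prints):

GRAPH_FIELD_SEP = "<SEP>"

items = [
    {"__id__": "ent-1", "entity_name": "Alice", "description": "a person",
     "source_id": GRAPH_FIELD_SEP.join(["chunk-aaa", "chunk-bbb"])},
    {"__id__": "ent-2", "entity_name": "Bob", "description": "a person",
     "source_id": "chunk-aaa"},
]


def prune_chunk(items: list[dict], chunk_id: str) -> dict[str, dict]:
    to_upsert: dict[str, dict] = {}
    for item in items:
        sources = item["source_id"].split(GRAPH_FIELD_SEP)
        if chunk_id not in sources:
            continue
        remaining = [s for s in sources if s != chunk_id]
        if not remaining:
            # No sources left: the record itself goes away (vdb.delete_entity).
            print(f"delete {item['__id__']}: no sources left")
        else:
            updated = item.copy()
            updated["source_id"] = GRAPH_FIELD_SEP.join(remaining)
            # Rebuild content the way entity insertion does: name + description.
            updated["content"] = updated.get("content") or (
                item.get("entity_name", "") + (item.get("description") or "")
            )
            to_upsert[item["__id__"]] = updated
    return to_upsert


print(prune_chunk(items, "chunk-aaa"))  # ent-1 keeps chunk-bbb; ent-2 is deleted
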