feat: fix delete by document id
@@ -359,14 +359,14 @@ class LightRAG:
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
-            meta_fields={"entity_name"},
+            meta_fields={"entity_name", "source_id", "content"},
         )
         self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
-            meta_fields={"src_id", "tgt_id"},
+            meta_fields={"src_id", "tgt_id", "source_id", "content"},
         )
         self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
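Note on the meta_fields change: meta_fields is the set of extra fields the vector storage persists next to each embedding, so the deletion path further down can read source_id back out of client_storage and rebuild content when it re-upserts a pruned item. Roughly the record shape this produces, as a sketch with made-up values (the "__id__" bookkeeping key appears later in this diff and is assumed to come from the default JSON vector storage):

```python
# Hypothetical entities_vdb record once "source_id" and "content" are meta fields:
{
    "__id__": "ent-3f2a9c...",                # md5-based id from compute_mdhash_id
    "entity_name": "ALICE",
    "source_id": "chunk-aaa<SEP>chunk-bbb",   # chunks this entity was extracted from
    "content": "ALICE is a software engineer at ...",
}
```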
@@ -1287,12 +1287,14 @@ class LightRAG:

             logger.debug(f"Starting deletion for document {doc_id}")

+            doc_to_chunk_id = doc_id.replace("doc", "chunk")
+
             # 2. Get all related chunks
-            chunks = await self.text_chunks.get_by_id(doc_id)
+            chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
             if not chunks:
                 return

-            chunk_ids = list(chunks.keys())
+            chunk_ids = {chunks["full_doc_id"].replace("doc", "chunk")}
             logger.debug(f"Found {len(chunk_ids)} chunks to delete")

             # 3. Before deleting, check the related entities and relationships for these chunks
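The doc_id to chunk id mapping leans on LightRAG's id scheme, where an id is a prefix plus the md5 of the content. A minimal sketch of the assumption behind replace("doc", "chunk"): it only lines up with a stored chunk id when the chunk was hashed from the same text as the document, e.g. a document small enough to fit in one chunk.

```python
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Same scheme LightRAG uses for its "doc-", "chunk-", "ent-" and "rel-" ids.
    return prefix + md5(content.encode()).hexdigest()

doc_id = compute_mdhash_id("some document text", prefix="doc-")
chunk_id = doc_id.replace("doc", "chunk")  # "chunk-" + the document's digest
# Assumption: this matches a stored chunk id only when the chunk content equals
# the document content, so doc and chunk share the same digest.
```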
@@ -1301,7 +1303,7 @@ class LightRAG:
                 entities = [
                     dp
                     for dp in self.entities_vdb.client_storage["data"]
-                    if dp.get("source_id") == chunk_id
+                    if chunk_id in dp.get("source_id")
                 ]
                 logger.debug(f"Chunk {chunk_id} has {len(entities)} related entities")

@@ -1309,7 +1311,7 @@ class LightRAG:
                 relations = [
                     dp
                     for dp in self.relationships_vdb.client_storage["data"]
-                    if dp.get("source_id") == chunk_id
+                    if chunk_id in dp.get("source_id")
                 ]
                 logger.debug(f"Chunk {chunk_id} has {len(relations)} related relations")

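Context for the == to in change in the two checks above: source_id can hold several chunk ids joined with GRAPH_FIELD_SEP (the "<SEP>" constant from LightRAG's prompt module), so strict equality missed entities and relations backed by more than one chunk. A small sketch of the three possible tests; substring membership is looser, and the split-based form used later in process_data is the exact one:

```python
GRAPH_FIELD_SEP = "<SEP>"  # assumed value of LightRAG's separator constant
source_id = "chunk-aaa" + GRAPH_FIELD_SEP + "chunk-bbb"

print(source_id == "chunk-aaa")                         # False: misses multi-chunk entries
print("chunk-aaa" in source_id)                         # True, but substring-based
print("chunk-aaa" in source_id.split(GRAPH_FIELD_SEP))  # True: exact per-chunk match
```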
@@ -1420,41 +1422,70 @@ class LightRAG:
                 f"Updated {len(entities_to_update)} entities and {len(relationships_to_update)} relationships."
             )

+            async def process_data(data_type, vdb, chunk_id):
+                # Check data (entities or relationships)
+                data_with_chunk = [
+                    dp
+                    for dp in vdb.client_storage["data"]
+                    if chunk_id in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
+                ]
+
+                data_for_vdb = {}
+                if data_with_chunk:
+                    logger.warning(
+                        f"found {len(data_with_chunk)} {data_type} still referencing chunk {chunk_id}"
+                    )
+
+                    for item in data_with_chunk:
+                        old_sources = item["source_id"].split(GRAPH_FIELD_SEP)
+                        new_sources = [src for src in old_sources if src != chunk_id]
+
+                        if not new_sources:
+                            logger.info(
+                                f"{data_type} {item.get('entity_name', 'N/A')} is deleted because source_id is not exists"
+                            )
+                            await vdb.delete_entity(item)
+                        else:
+                            item["source_id"] = GRAPH_FIELD_SEP.join(new_sources)
+                            item_id = item["__id__"]
+                            data_for_vdb[item_id] = item.copy()
+                            if data_type == "entities":
+                                data_for_vdb[item_id]["content"] = data_for_vdb[
+                                    item_id
+                                ].get("content") or (
+                                    item.get("entity_name", "")
+                                    + (item.get("description") or "")
+                                )
+                            else:  # relationships
+                                data_for_vdb[item_id]["content"] = data_for_vdb[
+                                    item_id
+                                ].get("content") or (
+                                    (item.get("keywords") or "")
+                                    + (item.get("src_id") or "")
+                                    + (item.get("tgt_id") or "")
+                                    + (item.get("description") or "")
+                                )
+
+                    if data_for_vdb:
+                        await vdb.upsert(data_for_vdb)
+                        logger.info(f"Successfully updated {data_type} in vector DB")
+
             # Add verification step
             async def verify_deletion():
                 # Verify if the document has been deleted
                 if await self.full_docs.get_by_id(doc_id):
-                    logger.error(f"Document {doc_id} still exists in full_docs")
+                    logger.warning(f"Document {doc_id} still exists in full_docs")

                 # Verify if chunks have been deleted
-                remaining_chunks = await self.text_chunks.get_by_id(doc_id)
+                remaining_chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
                 if remaining_chunks:
-                    logger.error(f"Found {len(remaining_chunks)} remaining chunks")
+                    logger.warning(f"Found {len(remaining_chunks)} remaining chunks")

                 # Verify entities and relationships
                 for chunk_id in chunk_ids:
-                    # Check entities
-                    entities_with_chunk = [
-                        dp
-                        for dp in self.entities_vdb.client_storage["data"]
-                        if chunk_id
-                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
-                    ]
-                    if entities_with_chunk:
-                        logger.error(
-                            f"Found {len(entities_with_chunk)} entities still referencing chunk {chunk_id}"
-                        )
-
-                    # Check relationships
-                    relations_with_chunk = [
-                        dp
-                        for dp in self.relationships_vdb.client_storage["data"]
-                        if chunk_id
-                        in (dp.get("source_id") or "").split(GRAPH_FIELD_SEP)
-                    ]
-                    if relations_with_chunk:
-                        logger.error(
-                            f"Found {len(relations_with_chunk)} relations still referencing chunk {chunk_id}"
-                        )
+                    await process_data("entities", self.entities_vdb, chunk_id)
+                    await process_data(
+                        "relationships", self.relationships_vdb, chunk_id
                     )

             await verify_deletion()
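process_data folds the two near-identical verification loops into one helper and, rather than only logging stale references, repairs them: the deleted chunk is dropped from each item's source_id list, items left with no sources are deleted from the vector store, and the rest are re-upserted with a rebuilt content field. The pruning rule in isolation, as a sketch with illustrative names and the separator assumed to be "<SEP>":

```python
from typing import Optional

GRAPH_FIELD_SEP = "<SEP>"

def prune_sources(source_id: str, chunk_id: str) -> Optional[str]:
    # Drop chunk_id from the <SEP>-joined list; None means "delete the record".
    remaining = [s for s in source_id.split(GRAPH_FIELD_SEP) if s != chunk_id]
    return GRAPH_FIELD_SEP.join(remaining) if remaining else None

assert prune_sources("chunk-a<SEP>chunk-b", "chunk-a") == "chunk-b"
assert prune_sources("chunk-a", "chunk-a") is None
```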
@@ -323,6 +323,7 @@ async def _merge_edges_then_upsert(
         tgt_id=tgt_id,
         description=description,
         keywords=keywords,
+        source_id=source_id,
     )

     return edge_data
@@ -548,6 +549,7 @@ async def extract_entities(
             compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                 "content": dp["entity_name"] + dp["description"],
                 "entity_name": dp["entity_name"],
+                "source_id": dp["source_id"],
             }
             for dp in all_entities_data
         }
@@ -558,6 +560,7 @@ async def extract_entities(
             compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
+                "source_id": dp["source_id"],
                 "content": dp["keywords"]
                 + dp["src_id"]
                 + dp["tgt_id"]
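With source_id now carried through _merge_edges_then_upsert and written into both vector-store payloads, every ent- and rel- record can be traced back to the chunks it came from, which is what the deletion path above relies on. An illustrative relationship payload under this change (values are hypothetical; content is the keywords, src_id, tgt_id and description concatenated, as in the comprehension above):

```python
# Hypothetical data_for_vdb entry built by extract_entities for one relationship:
{
    "rel-7be14d...": {
        "src_id": "ALICE",
        "tgt_id": "BOB",
        "source_id": "chunk-aaa<SEP>chunk-bbb",
        "content": "mentorship, collaborationALICEBOBAlice mentors Bob on ...",
    }
}
```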
@@ -1113,7 +1116,7 @@ async def _get_node_data(
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
-        key=lambda x: x["description"],
+        key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
     )
     logger.debug(
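The None-guard on the truncation key matters because a graph node that only ever appears as a relationship endpoint can end up with description=None, and truncate_list_by_token_size tokenizes key(item) for every item, which would fail on None. The same guard is applied in the three hunks below. A minimal sketch:

```python
node_datas = [{"description": "A person who ..."}, {"description": None}]

key = lambda x: x["description"] if x["description"] is not None else ""
assert [key(x) for x in node_datas] == ["A person who ...", ""]
```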
@@ -1296,7 +1299,7 @@ async def _find_most_related_edges_from_entities(
     )
     all_edges_data = truncate_list_by_token_size(
         all_edges_data,
-        key=lambda x: x["description"],
+        key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
     )

@@ -1350,7 +1353,7 @@ async def _get_edge_data(
     )
     edge_datas = truncate_list_by_token_size(
         edge_datas,
-        key=lambda x: x["description"],
+        key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
     )
     use_entities, use_text_units = await asyncio.gather(
@@ -1454,7 +1457,7 @@ async def _find_most_related_entities_from_relationships(
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
-        key=lambda x: x["description"],
+        key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
     )
     logger.debug(
|