Fix linting

This commit is contained in:
yangdx
2025-04-14 12:08:56 +08:00
parent 40240bc79e
commit 5c1d4201f9
3 changed files with 136 additions and 108 deletions

View File

@@ -1425,7 +1425,6 @@ class LightRAG:
async def _query_done(self): async def _query_done(self):
await self.llm_response_cache.index_done_callback() await self.llm_response_cache.index_done_callback()
async def aclear_cache(self, modes: list[str] | None = None) -> None: async def aclear_cache(self, modes: list[str] | None = None) -> None:
"""Clear cache data from the LLM response cache storage. """Clear cache data from the LLM response cache storage.
@@ -1479,7 +1478,6 @@ class LightRAG:
"""Synchronous version of aclear_cache.""" """Synchronous version of aclear_cache."""
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes)) return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
async def get_docs_by_status( async def get_docs_by_status(
self, status: DocStatus self, status: DocStatus
) -> dict[str, DocProcessingStatus]: ) -> dict[str, DocProcessingStatus]:
@@ -1748,7 +1746,6 @@ class LightRAG:
except Exception as e: except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}") logger.error(f"Error while deleting document {doc_id}: {e}")
async def adelete_by_entity(self, entity_name: str) -> None: async def adelete_by_entity(self, entity_name: str) -> None:
"""Asynchronously delete an entity and all its relationships. """Asynchronously delete an entity and all its relationships.
@@ -1756,11 +1753,12 @@ class LightRAG:
entity_name: Name of the entity to delete entity_name: Name of the entity to delete
""" """
from .utils_graph import adelete_by_entity from .utils_graph import adelete_by_entity
return await adelete_by_entity( return await adelete_by_entity(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
self.relationships_vdb, self.relationships_vdb,
entity_name entity_name,
) )
def delete_by_entity(self, entity_name: str) -> None: def delete_by_entity(self, entity_name: str) -> None:
@@ -1775,16 +1773,19 @@ class LightRAG:
target_entity: Name of the target entity target_entity: Name of the target entity
""" """
from .utils_graph import adelete_by_relation from .utils_graph import adelete_by_relation
return await adelete_by_relation( return await adelete_by_relation(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.relationships_vdb, self.relationships_vdb,
source_entity, source_entity,
target_entity target_entity,
) )
def delete_by_relation(self, source_entity: str, target_entity: str) -> None: def delete_by_relation(self, source_entity: str, target_entity: str) -> None:
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity)) return loop.run_until_complete(
self.adelete_by_relation(source_entity, target_entity)
)
async def get_processing_status(self) -> dict[str, int]: async def get_processing_status(self) -> dict[str, int]:
"""Get current document processing status counts """Get current document processing status counts
@@ -1799,11 +1800,12 @@ class LightRAG:
) -> dict[str, str | None | dict[str, str]]: ) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity""" """Get detailed information of an entity"""
from .utils_graph import get_entity_info from .utils_graph import get_entity_info
return await get_entity_info( return await get_entity_info(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
entity_name, entity_name,
include_vector_data include_vector_data,
) )
async def get_relation_info( async def get_relation_info(
@@ -1811,12 +1813,13 @@ class LightRAG:
) -> dict[str, str | None | dict[str, str]]: ) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of a relationship""" """Get detailed information of a relationship"""
from .utils_graph import get_relation_info from .utils_graph import get_relation_info
return await get_relation_info( return await get_relation_info(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.relationships_vdb, self.relationships_vdb,
src_entity, src_entity,
tgt_entity, tgt_entity,
include_vector_data include_vector_data,
) )
async def aedit_entity( async def aedit_entity(
@@ -1835,13 +1838,14 @@ class LightRAG:
Dictionary containing updated entity information Dictionary containing updated entity information
""" """
from .utils_graph import aedit_entity from .utils_graph import aedit_entity
return await aedit_entity( return await aedit_entity(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
self.relationships_vdb, self.relationships_vdb,
entity_name, entity_name,
updated_data, updated_data,
allow_rename allow_rename,
) )
def edit_entity( def edit_entity(
@@ -1868,13 +1872,14 @@ class LightRAG:
Dictionary containing updated relation information Dictionary containing updated relation information
""" """
from .utils_graph import aedit_relation from .utils_graph import aedit_relation
return await aedit_relation( return await aedit_relation(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
self.relationships_vdb, self.relationships_vdb,
source_entity, source_entity,
target_entity, target_entity,
updated_data updated_data,
) )
def edit_relation( def edit_relation(
@@ -1900,12 +1905,13 @@ class LightRAG:
Dictionary containing created entity information Dictionary containing created entity information
""" """
from .utils_graph import acreate_entity from .utils_graph import acreate_entity
return await acreate_entity( return await acreate_entity(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
self.relationships_vdb, self.relationships_vdb,
entity_name, entity_name,
entity_data entity_data,
) )
def create_entity( def create_entity(
@@ -1930,13 +1936,14 @@ class LightRAG:
Dictionary containing created relation information Dictionary containing created relation information
""" """
from .utils_graph import acreate_relation from .utils_graph import acreate_relation
return await acreate_relation( return await acreate_relation(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
self.relationships_vdb, self.relationships_vdb,
source_entity, source_entity,
target_entity, target_entity,
relation_data relation_data,
) )
def create_relation( def create_relation(
@@ -1975,6 +1982,7 @@ class LightRAG:
Dictionary containing the merged entity information Dictionary containing the merged entity information
""" """
from .utils_graph import amerge_entities from .utils_graph import amerge_entities
return await amerge_entities( return await amerge_entities(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
@@ -1982,7 +1990,7 @@ class LightRAG:
source_entities, source_entities,
target_entity, target_entity,
merge_strategy, merge_strategy,
target_entity_data target_entity_data,
) )
def merge_entities( def merge_entities(
@@ -2018,14 +2026,14 @@ class LightRAG:
include_vector_data: Whether to include data from the vector database. include_vector_data: Whether to include data from the vector database.
""" """
from .utils import aexport_data as utils_aexport_data from .utils import aexport_data as utils_aexport_data
await utils_aexport_data( await utils_aexport_data(
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
self.relationships_vdb, self.relationships_vdb,
output_path, output_path,
file_format, file_format,
include_vector_data include_vector_data,
) )
def export_data( def export_data(

View File

@@ -903,7 +903,7 @@ async def aexport_data(
) -> None: ) -> None:
""" """
Asynchronously exports all entities, relations, and relationships to various formats. Asynchronously exports all entities, relations, and relationships to various formats.
Args: Args:
chunk_entity_relation_graph: Graph storage instance for entities and relations chunk_entity_relation_graph: Graph storage instance for entities and relations
entities_vdb: Vector database storage for entities entities_vdb: Vector database storage for entities
@@ -927,22 +927,24 @@ async def aexport_data(
# Get entity information from graph # Get entity information from graph
node_data = await chunk_entity_relation_graph.get_node(entity_name) node_data = await chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None source_id = node_data.get("source_id") if node_data else None
entity_info = { entity_info = {
"graph_data": node_data, "graph_data": node_data,
"source_id": source_id, "source_id": source_id,
} }
# Optional: Get vector database information # Optional: Get vector database information
if include_vector_data: if include_vector_data:
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
vector_data = await entities_vdb.get_by_id(entity_id) vector_data = await entities_vdb.get_by_id(entity_id)
entity_info["vector_data"] = vector_data entity_info["vector_data"] = vector_data
entity_row = { entity_row = {
"entity_name": entity_name, "entity_name": entity_name,
"source_id": source_id, "source_id": source_id,
"graph_data": str(entity_info["graph_data"]), # Convert to string to ensure compatibility "graph_data": str(
entity_info["graph_data"]
), # Convert to string to ensure compatibility
} }
if include_vector_data and "vector_data" in entity_info: if include_vector_data and "vector_data" in entity_info:
entity_row["vector_data"] = str(entity_info["vector_data"]) entity_row["vector_data"] = str(entity_info["vector_data"])
@@ -963,18 +965,18 @@ async def aexport_data(
src_entity, tgt_entity src_entity, tgt_entity
) )
source_id = edge_data.get("source_id") if edge_data else None source_id = edge_data.get("source_id") if edge_data else None
relation_info = { relation_info = {
"graph_data": edge_data, "graph_data": edge_data,
"source_id": source_id, "source_id": source_id,
} }
# Optional: Get vector database information # Optional: Get vector database information
if include_vector_data: if include_vector_data:
rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-") rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
vector_data = await relationships_vdb.get_by_id(rel_id) vector_data = await relationships_vdb.get_by_id(rel_id)
relation_info["vector_data"] = vector_data relation_info["vector_data"] = vector_data
relation_row = { relation_row = {
"src_entity": src_entity, "src_entity": src_entity,
"tgt_entity": tgt_entity, "tgt_entity": tgt_entity,
@@ -1010,9 +1012,7 @@ async def aexport_data(
# Relations # Relations
if relations_data: if relations_data:
csvfile.write("# RELATIONS\n") csvfile.write("# RELATIONS\n")
writer = csv.DictWriter( writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys())
csvfile, fieldnames=relations_data[0].keys()
)
writer.writeheader() writer.writeheader()
writer.writerows(relations_data) writer.writerows(relations_data)
csvfile.write("\n\n") csvfile.write("\n\n")
@@ -1029,17 +1029,13 @@ async def aexport_data(
elif file_format == "excel": elif file_format == "excel":
# Excel export # Excel export
import pandas as pd import pandas as pd
entities_df = ( entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
)
relations_df = ( relations_df = (
pd.DataFrame(relations_data) if relations_data else pd.DataFrame() pd.DataFrame(relations_data) if relations_data else pd.DataFrame()
) )
relationships_df = ( relationships_df = (
pd.DataFrame(relationships_data) pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame()
if relationships_data
else pd.DataFrame()
) )
with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer: with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
@@ -1063,9 +1059,7 @@ async def aexport_data(
# Write header # Write header
mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n") mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n")
mdfile.write( mdfile.write(
"| " "| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n"
+ " | ".join(["---"] * len(entities_data[0].keys()))
+ " |\n"
) )
# Write rows # Write rows
@@ -1083,17 +1077,13 @@ async def aexport_data(
# Write header # Write header
mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n") mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
mdfile.write( mdfile.write(
"| " "| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n"
+ " | ".join(["---"] * len(relations_data[0].keys()))
+ " |\n"
) )
# Write rows # Write rows
for relation in relations_data: for relation in relations_data:
mdfile.write( mdfile.write(
"| " "| " + " | ".join(str(v) for v in relation.values()) + " |\n"
+ " | ".join(str(v) for v in relation.values())
+ " |\n"
) )
mdfile.write("\n\n") mdfile.write("\n\n")
else: else:
@@ -1103,9 +1093,7 @@ async def aexport_data(
mdfile.write("## Relationships\n\n") mdfile.write("## Relationships\n\n")
if relationships_data: if relationships_data:
# Write header # Write header
mdfile.write( mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n")
"| " + " | ".join(relationships_data[0].keys()) + " |\n"
)
mdfile.write( mdfile.write(
"| " "| "
+ " | ".join(["---"] * len(relationships_data[0].keys())) + " | ".join(["---"] * len(relationships_data[0].keys()))
@@ -1160,9 +1148,7 @@ async def aexport_data(
k: max(len(k), max(len(str(r[k])) for r in relations_data)) k: max(len(k), max(len(str(r[k])) for r in relations_data))
for k in relations_data[0] for k in relations_data[0]
} }
header = " ".join( header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0])
k.ljust(col_widths[k]) for k in relations_data[0]
)
txtfile.write(header + "\n") txtfile.write(header + "\n")
txtfile.write("-" * len(header) + "\n") txtfile.write("-" * len(header) + "\n")
@@ -1221,7 +1207,7 @@ def export_data(
) -> None: ) -> None:
""" """
Synchronously exports all entities, relations, and relationships to various formats. Synchronously exports all entities, relations, and relationships to various formats.
Args: Args:
chunk_entity_relation_graph: Graph storage instance for entities and relations chunk_entity_relation_graph: Graph storage instance for entities and relations
entities_vdb: Vector database storage for entities entities_vdb: Vector database storage for entities
@@ -1247,7 +1233,7 @@ def export_data(
relationships_vdb, relationships_vdb,
output_path, output_path,
file_format, file_format,
include_vector_data include_vector_data,
) )
) )

View File

@@ -7,11 +7,9 @@ from .kg.shared_storage import get_graph_db_lock
from .prompt import GRAPH_FIELD_SEP from .prompt import GRAPH_FIELD_SEP
from .utils import compute_mdhash_id, logger, StorageNameSpace from .utils import compute_mdhash_id, logger, StorageNameSpace
async def adelete_by_entity( async def adelete_by_entity(
chunk_entity_relation_graph, chunk_entity_relation_graph, entities_vdb, relationships_vdb, entity_name: str
entities_vdb,
relationships_vdb,
entity_name: str
) -> None: ) -> None:
"""Asynchronously delete an entity and all its relationships. """Asynchronously delete an entity and all its relationships.
@@ -32,11 +30,16 @@ async def adelete_by_entity(
logger.info( logger.info(
f"Entity '{entity_name}' and its relationships have been deleted." f"Entity '{entity_name}' and its relationships have been deleted."
) )
await _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) await _delete_by_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
except Exception as e: except Exception as e:
logger.error(f"Error while deleting entity '{entity_name}': {e}") logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _delete_by_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity deletion is complete, ensures updates are persisted""" """Callback after entity deletion is complete, ensures updates are persisted"""
await asyncio.gather( await asyncio.gather(
*[ *[
@@ -49,11 +52,12 @@ async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_r
] ]
) )
async def adelete_by_relation( async def adelete_by_relation(
chunk_entity_relation_graph, chunk_entity_relation_graph,
relationships_vdb, relationships_vdb,
source_entity: str, source_entity: str,
target_entity: str target_entity: str,
) -> None: ) -> None:
"""Asynchronously delete a relation between two entities. """Asynchronously delete a relation between two entities.
@@ -97,6 +101,7 @@ async def adelete_by_relation(
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}" f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
) )
async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None: async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
"""Callback after relation deletion is complete, ensures updates are persisted""" """Callback after relation deletion is complete, ensures updates are persisted"""
await asyncio.gather( await asyncio.gather(
@@ -109,13 +114,14 @@ async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph)
] ]
) )
async def aedit_entity( async def aedit_entity(
chunk_entity_relation_graph, chunk_entity_relation_graph,
entities_vdb, entities_vdb,
relationships_vdb, relationships_vdb,
entity_name: str, entity_name: str,
updated_data: dict[str, str], updated_data: dict[str, str],
allow_rename: bool = True allow_rename: bool = True,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Asynchronously edit entity information. """Asynchronously edit entity information.
@@ -183,9 +189,7 @@ async def aedit_entity(
relations_to_update = [] relations_to_update = []
relations_to_delete = [] relations_to_delete = []
# Get all edges related to the original entity # Get all edges related to the original entity
edges = await chunk_entity_relation_graph.get_node_edges( edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
entity_name
)
if edges: if edges:
# Recreate edges for the new entity # Recreate edges for the new entity
for source, target in edges: for source, target in edges:
@@ -291,15 +295,25 @@ async def aedit_entity(
await entities_vdb.upsert(entity_data) await entities_vdb.upsert(entity_data)
# 4. Save changes # 4. Save changes
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) await _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(f"Entity '{entity_name}' successfully updated") logger.info(f"Entity '{entity_name}' successfully updated")
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True) return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name,
include_vector_data=True,
)
except Exception as e: except Exception as e:
logger.error(f"Error while editing entity '{entity_name}': {e}") logger.error(f"Error while editing entity '{entity_name}': {e}")
raise raise
async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity editing is complete, ensures updates are persisted""" """Callback after entity editing is complete, ensures updates are persisted"""
await asyncio.gather( await asyncio.gather(
*[ *[
@@ -312,13 +326,14 @@ async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relati
] ]
) )
async def aedit_relation( async def aedit_relation(
chunk_entity_relation_graph, chunk_entity_relation_graph,
entities_vdb, entities_vdb,
relationships_vdb, relationships_vdb,
source_entity: str, source_entity: str,
target_entity: str, target_entity: str,
updated_data: dict[str, Any] updated_data: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Asynchronously edit relation information. """Asynchronously edit relation information.
@@ -402,7 +417,11 @@ async def aedit_relation(
f"Relation from '{source_entity}' to '{target_entity}' successfully updated" f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
) )
return await get_relation_info( return await get_relation_info(
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True chunk_entity_relation_graph,
relationships_vdb,
source_entity,
target_entity,
include_vector_data=True,
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
@@ -410,6 +429,7 @@ async def aedit_relation(
) )
raise raise
async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None: async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
"""Callback after relation editing is complete, ensures updates are persisted""" """Callback after relation editing is complete, ensures updates are persisted"""
await asyncio.gather( await asyncio.gather(
@@ -422,12 +442,13 @@ async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) ->
] ]
) )
async def acreate_entity( async def acreate_entity(
chunk_entity_relation_graph, chunk_entity_relation_graph,
entities_vdb, entities_vdb,
relationships_vdb, relationships_vdb,
entity_name: str, entity_name: str,
entity_data: dict[str, Any] entity_data: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Asynchronously create a new entity. """Asynchronously create a new entity.
@@ -487,21 +508,29 @@ async def acreate_entity(
await entities_vdb.upsert(entity_data_for_vdb) await entities_vdb.upsert(entity_data_for_vdb)
# Save changes # Save changes
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) await _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(f"Entity '{entity_name}' successfully created") logger.info(f"Entity '{entity_name}' successfully created")
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True) return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name,
include_vector_data=True,
)
except Exception as e: except Exception as e:
logger.error(f"Error while creating entity '{entity_name}': {e}") logger.error(f"Error while creating entity '{entity_name}': {e}")
raise raise
async def acreate_relation( async def acreate_relation(
chunk_entity_relation_graph, chunk_entity_relation_graph,
entities_vdb, entities_vdb,
relationships_vdb, relationships_vdb,
source_entity: str, source_entity: str,
target_entity: str, target_entity: str,
relation_data: dict[str, Any] relation_data: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Asynchronously create a new relation between entities. """Asynchronously create a new relation between entities.
@@ -523,12 +552,8 @@ async def acreate_relation(
async with graph_db_lock: async with graph_db_lock:
try: try:
# Check if both entities exist # Check if both entities exist
source_exists = await chunk_entity_relation_graph.has_node( source_exists = await chunk_entity_relation_graph.has_node(source_entity)
source_entity target_exists = await chunk_entity_relation_graph.has_node(target_entity)
)
target_exists = await chunk_entity_relation_graph.has_node(
target_entity
)
if not source_exists: if not source_exists:
raise ValueError(f"Source entity '{source_entity}' does not exist") raise ValueError(f"Source entity '{source_entity}' does not exist")
@@ -594,7 +619,11 @@ async def acreate_relation(
f"Relation from '{source_entity}' to '{target_entity}' successfully created" f"Relation from '{source_entity}' to '{target_entity}' successfully created"
) )
return await get_relation_info( return await get_relation_info(
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True chunk_entity_relation_graph,
relationships_vdb,
source_entity,
target_entity,
include_vector_data=True,
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
@@ -602,6 +631,7 @@ async def acreate_relation(
) )
raise raise
async def amerge_entities( async def amerge_entities(
chunk_entity_relation_graph, chunk_entity_relation_graph,
entities_vdb, entities_vdb,
@@ -657,18 +687,14 @@ async def amerge_entities(
# 1. Check if all source entities exist # 1. Check if all source entities exist
source_entities_data = {} source_entities_data = {}
for entity_name in source_entities: for entity_name in source_entities:
node_exists = await chunk_entity_relation_graph.has_node( node_exists = await chunk_entity_relation_graph.has_node(entity_name)
entity_name
)
if not node_exists: if not node_exists:
raise ValueError(f"Source entity '{entity_name}' does not exist") raise ValueError(f"Source entity '{entity_name}' does not exist")
node_data = await chunk_entity_relation_graph.get_node(entity_name) node_data = await chunk_entity_relation_graph.get_node(entity_name)
source_entities_data[entity_name] = node_data source_entities_data[entity_name] = node_data
# 2. Check if target entity exists and get its data if it does # 2. Check if target entity exists and get its data if it does
target_exists = await chunk_entity_relation_graph.has_node( target_exists = await chunk_entity_relation_graph.has_node(target_entity)
target_entity
)
existing_target_entity_data = {} existing_target_entity_data = {}
if target_exists: if target_exists:
existing_target_entity_data = ( existing_target_entity_data = (
@@ -693,9 +719,7 @@ async def amerge_entities(
all_relations = [] all_relations = []
for entity_name in source_entities: for entity_name in source_entities:
# Get all relationships of the source entities # Get all relationships of the source entities
edges = await chunk_entity_relation_graph.get_node_edges( edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
entity_name
)
if edges: if edges:
for src, tgt in edges: for src, tgt in edges:
# Ensure src is the current entity # Ensure src is the current entity
@@ -842,17 +866,25 @@ async def amerge_entities(
) )
# 10. Save changes # 10. Save changes
await _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) await _merge_entities_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info( logger.info(
f"Successfully merged {len(source_entities)} entities into '{target_entity}'" f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
) )
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, target_entity, include_vector_data=True) return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
target_entity,
include_vector_data=True,
)
except Exception as e: except Exception as e:
logger.error(f"Error merging entities: {e}") logger.error(f"Error merging entities: {e}")
raise raise
def _merge_entity_attributes( def _merge_entity_attributes(
entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str] entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -902,6 +934,7 @@ def _merge_entity_attributes(
return merged_data return merged_data
def _merge_relation_attributes( def _merge_relation_attributes(
relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str] relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -925,9 +958,7 @@ def _merge_relation_attributes(
for key in all_keys: for key in all_keys:
# Get all values for this key # Get all values for this key
values = [ values = [
data.get(key) data.get(key) for data in relation_data_list if data.get(key) is not None
for data in relation_data_list
if data.get(key) is not None
] ]
if not values: if not values:
@@ -961,7 +992,10 @@ def _merge_relation_attributes(
return merged_data return merged_data
async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _merge_entities_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity merging is complete, ensures updates are persisted""" """Callback after entity merging is complete, ensures updates are persisted"""
await asyncio.gather( await asyncio.gather(
*[ *[
@@ -974,11 +1008,12 @@ async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_rel
] ]
) )
async def get_entity_info( async def get_entity_info(
chunk_entity_relation_graph, chunk_entity_relation_graph,
entities_vdb, entities_vdb,
entity_name: str, entity_name: str,
include_vector_data: bool = False include_vector_data: bool = False,
) -> dict[str, str | None | dict[str, str]]: ) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity""" """Get detailed information of an entity"""
@@ -1000,19 +1035,18 @@ async def get_entity_info(
return result return result
async def get_relation_info( async def get_relation_info(
chunk_entity_relation_graph, chunk_entity_relation_graph,
relationships_vdb, relationships_vdb,
src_entity: str, src_entity: str,
tgt_entity: str, tgt_entity: str,
include_vector_data: bool = False include_vector_data: bool = False,
) -> dict[str, str | None | dict[str, str]]: ) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of a relationship""" """Get detailed information of a relationship"""
# Get information from the graph # Get information from the graph
edge_data = await chunk_entity_relation_graph.get_edge( edge_data = await chunk_entity_relation_graph.get_edge(src_entity, tgt_entity)
src_entity, tgt_entity
)
source_id = edge_data.get("source_id") if edge_data else None source_id = edge_data.get("source_id") if edge_data else None
result: dict[str, str | None | dict[str, str]] = { result: dict[str, str | None | dict[str, str]] = {