Fix linting
This commit is contained in:
@@ -7,11 +7,9 @@ from .kg.shared_storage import get_graph_db_lock
|
||||
from .prompt import GRAPH_FIELD_SEP
|
||||
from .utils import compute_mdhash_id, logger, StorageNameSpace
|
||||
|
||||
|
||||
async def adelete_by_entity(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
entity_name: str
|
||||
chunk_entity_relation_graph, entities_vdb, relationships_vdb, entity_name: str
|
||||
) -> None:
|
||||
"""Asynchronously delete an entity and all its relationships.
|
||||
|
||||
@@ -32,11 +30,16 @@ async def adelete_by_entity(
|
||||
logger.info(
|
||||
f"Entity '{entity_name}' and its relationships have been deleted."
|
||||
)
|
||||
await _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
||||
await _delete_by_entity_done(
|
||||
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting entity '{entity_name}': {e}")
|
||||
|
||||
async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
|
||||
|
||||
async def _delete_by_entity_done(
|
||||
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||
) -> None:
|
||||
"""Callback after entity deletion is complete, ensures updates are persisted"""
|
||||
await asyncio.gather(
|
||||
*[
|
||||
@@ -49,11 +52,12 @@ async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_r
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def adelete_by_relation(
|
||||
chunk_entity_relation_graph,
|
||||
relationships_vdb,
|
||||
source_entity: str,
|
||||
target_entity: str
|
||||
source_entity: str,
|
||||
target_entity: str,
|
||||
) -> None:
|
||||
"""Asynchronously delete a relation between two entities.
|
||||
|
||||
@@ -97,6 +101,7 @@ async def adelete_by_relation(
|
||||
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
|
||||
)
|
||||
|
||||
|
||||
async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
|
||||
"""Callback after relation deletion is complete, ensures updates are persisted"""
|
||||
await asyncio.gather(
|
||||
@@ -109,13 +114,14 @@ async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def aedit_entity(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
entity_name: str,
|
||||
updated_data: dict[str, str],
|
||||
allow_rename: bool = True
|
||||
entity_name: str,
|
||||
updated_data: dict[str, str],
|
||||
allow_rename: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Asynchronously edit entity information.
|
||||
|
||||
@@ -183,9 +189,7 @@ async def aedit_entity(
|
||||
relations_to_update = []
|
||||
relations_to_delete = []
|
||||
# Get all edges related to the original entity
|
||||
edges = await chunk_entity_relation_graph.get_node_edges(
|
||||
entity_name
|
||||
)
|
||||
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
|
||||
if edges:
|
||||
# Recreate edges for the new entity
|
||||
for source, target in edges:
|
||||
@@ -291,15 +295,25 @@ async def aedit_entity(
|
||||
await entities_vdb.upsert(entity_data)
|
||||
|
||||
# 4. Save changes
|
||||
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
||||
await _edit_entity_done(
|
||||
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||
)
|
||||
|
||||
logger.info(f"Entity '{entity_name}' successfully updated")
|
||||
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
|
||||
return await get_entity_info(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
entity_name,
|
||||
include_vector_data=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while editing entity '{entity_name}': {e}")
|
||||
raise
|
||||
|
||||
async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
|
||||
|
||||
async def _edit_entity_done(
|
||||
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||
) -> None:
|
||||
"""Callback after entity editing is complete, ensures updates are persisted"""
|
||||
await asyncio.gather(
|
||||
*[
|
||||
@@ -312,13 +326,14 @@ async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relati
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def aedit_relation(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
source_entity: str,
|
||||
target_entity: str,
|
||||
updated_data: dict[str, Any]
|
||||
source_entity: str,
|
||||
target_entity: str,
|
||||
updated_data: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Asynchronously edit relation information.
|
||||
|
||||
@@ -402,7 +417,11 @@ async def aedit_relation(
|
||||
f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
|
||||
)
|
||||
return await get_relation_info(
|
||||
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
|
||||
chunk_entity_relation_graph,
|
||||
relationships_vdb,
|
||||
source_entity,
|
||||
target_entity,
|
||||
include_vector_data=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -410,6 +429,7 @@ async def aedit_relation(
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
|
||||
"""Callback after relation editing is complete, ensures updates are persisted"""
|
||||
await asyncio.gather(
|
||||
@@ -422,12 +442,13 @@ async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) ->
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def acreate_entity(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
entity_name: str,
|
||||
entity_data: dict[str, Any]
|
||||
entity_name: str,
|
||||
entity_data: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Asynchronously create a new entity.
|
||||
|
||||
@@ -487,21 +508,29 @@ async def acreate_entity(
|
||||
await entities_vdb.upsert(entity_data_for_vdb)
|
||||
|
||||
# Save changes
|
||||
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
||||
await _edit_entity_done(
|
||||
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||
)
|
||||
|
||||
logger.info(f"Entity '{entity_name}' successfully created")
|
||||
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
|
||||
return await get_entity_info(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
entity_name,
|
||||
include_vector_data=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while creating entity '{entity_name}': {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def acreate_relation(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
relationships_vdb,
|
||||
source_entity: str,
|
||||
target_entity: str,
|
||||
relation_data: dict[str, Any]
|
||||
source_entity: str,
|
||||
target_entity: str,
|
||||
relation_data: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Asynchronously create a new relation between entities.
|
||||
|
||||
@@ -523,12 +552,8 @@ async def acreate_relation(
|
||||
async with graph_db_lock:
|
||||
try:
|
||||
# Check if both entities exist
|
||||
source_exists = await chunk_entity_relation_graph.has_node(
|
||||
source_entity
|
||||
)
|
||||
target_exists = await chunk_entity_relation_graph.has_node(
|
||||
target_entity
|
||||
)
|
||||
source_exists = await chunk_entity_relation_graph.has_node(source_entity)
|
||||
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
|
||||
|
||||
if not source_exists:
|
||||
raise ValueError(f"Source entity '{source_entity}' does not exist")
|
||||
@@ -594,7 +619,11 @@ async def acreate_relation(
|
||||
f"Relation from '{source_entity}' to '{target_entity}' successfully created"
|
||||
)
|
||||
return await get_relation_info(
|
||||
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
|
||||
chunk_entity_relation_graph,
|
||||
relationships_vdb,
|
||||
source_entity,
|
||||
target_entity,
|
||||
include_vector_data=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -602,6 +631,7 @@ async def acreate_relation(
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def amerge_entities(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
@@ -657,18 +687,14 @@ async def amerge_entities(
|
||||
# 1. Check if all source entities exist
|
||||
source_entities_data = {}
|
||||
for entity_name in source_entities:
|
||||
node_exists = await chunk_entity_relation_graph.has_node(
|
||||
entity_name
|
||||
)
|
||||
node_exists = await chunk_entity_relation_graph.has_node(entity_name)
|
||||
if not node_exists:
|
||||
raise ValueError(f"Source entity '{entity_name}' does not exist")
|
||||
node_data = await chunk_entity_relation_graph.get_node(entity_name)
|
||||
source_entities_data[entity_name] = node_data
|
||||
|
||||
# 2. Check if target entity exists and get its data if it does
|
||||
target_exists = await chunk_entity_relation_graph.has_node(
|
||||
target_entity
|
||||
)
|
||||
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
|
||||
existing_target_entity_data = {}
|
||||
if target_exists:
|
||||
existing_target_entity_data = (
|
||||
@@ -693,9 +719,7 @@ async def amerge_entities(
|
||||
all_relations = []
|
||||
for entity_name in source_entities:
|
||||
# Get all relationships of the source entities
|
||||
edges = await chunk_entity_relation_graph.get_node_edges(
|
||||
entity_name
|
||||
)
|
||||
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
|
||||
if edges:
|
||||
for src, tgt in edges:
|
||||
# Ensure src is the current entity
|
||||
@@ -842,17 +866,25 @@ async def amerge_entities(
|
||||
)
|
||||
|
||||
# 10. Save changes
|
||||
await _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
|
||||
await _merge_entities_done(
|
||||
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
|
||||
)
|
||||
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, target_entity, include_vector_data=True)
|
||||
return await get_entity_info(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
target_entity,
|
||||
include_vector_data=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error merging entities: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _merge_entity_attributes(
|
||||
entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
@@ -902,6 +934,7 @@ def _merge_entity_attributes(
|
||||
|
||||
return merged_data
|
||||
|
||||
|
||||
def _merge_relation_attributes(
|
||||
relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
@@ -925,9 +958,7 @@ def _merge_relation_attributes(
|
||||
for key in all_keys:
|
||||
# Get all values for this key
|
||||
values = [
|
||||
data.get(key)
|
||||
for data in relation_data_list
|
||||
if data.get(key) is not None
|
||||
data.get(key) for data in relation_data_list if data.get(key) is not None
|
||||
]
|
||||
|
||||
if not values:
|
||||
@@ -961,7 +992,10 @@ def _merge_relation_attributes(
|
||||
|
||||
return merged_data
|
||||
|
||||
async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
|
||||
|
||||
async def _merge_entities_done(
|
||||
entities_vdb, relationships_vdb, chunk_entity_relation_graph
|
||||
) -> None:
|
||||
"""Callback after entity merging is complete, ensures updates are persisted"""
|
||||
await asyncio.gather(
|
||||
*[
|
||||
@@ -974,11 +1008,12 @@ async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_rel
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def get_entity_info(
|
||||
chunk_entity_relation_graph,
|
||||
entities_vdb,
|
||||
entity_name: str,
|
||||
include_vector_data: bool = False
|
||||
entity_name: str,
|
||||
include_vector_data: bool = False,
|
||||
) -> dict[str, str | None | dict[str, str]]:
|
||||
"""Get detailed information of an entity"""
|
||||
|
||||
@@ -1000,19 +1035,18 @@ async def get_entity_info(
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def get_relation_info(
|
||||
chunk_entity_relation_graph,
|
||||
relationships_vdb,
|
||||
src_entity: str,
|
||||
tgt_entity: str,
|
||||
include_vector_data: bool = False
|
||||
src_entity: str,
|
||||
tgt_entity: str,
|
||||
include_vector_data: bool = False,
|
||||
) -> dict[str, str | None | dict[str, str]]:
|
||||
"""Get detailed information of a relationship"""
|
||||
|
||||
# Get information from the graph
|
||||
edge_data = await chunk_entity_relation_graph.get_edge(
|
||||
src_entity, tgt_entity
|
||||
)
|
||||
edge_data = await chunk_entity_relation_graph.get_edge(src_entity, tgt_entity)
|
||||
source_id = edge_data.get("source_id") if edge_data else None
|
||||
|
||||
result: dict[str, str | None | dict[str, str]] = {
|
||||
|
Reference in New Issue
Block a user