Fix linting

This commit is contained in:
yangdx
2025-04-14 12:08:56 +08:00
parent 40240bc79e
commit 5c1d4201f9
3 changed files with 136 additions and 108 deletions

View File

@@ -7,11 +7,9 @@ from .kg.shared_storage import get_graph_db_lock
from .prompt import GRAPH_FIELD_SEP
from .utils import compute_mdhash_id, logger, StorageNameSpace
async def adelete_by_entity(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
entity_name: str
chunk_entity_relation_graph, entities_vdb, relationships_vdb, entity_name: str
) -> None:
"""Asynchronously delete an entity and all its relationships.
@@ -32,11 +30,16 @@ async def adelete_by_entity(
logger.info(
f"Entity '{entity_name}' and its relationships have been deleted."
)
await _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _delete_by_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
except Exception as e:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _delete_by_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity deletion is complete, ensures updates are persisted"""
await asyncio.gather(
*[
@@ -49,11 +52,12 @@ async def _delete_by_entity_done(entities_vdb, relationships_vdb, chunk_entity_r
]
)
async def adelete_by_relation(
chunk_entity_relation_graph,
relationships_vdb,
source_entity: str,
target_entity: str
source_entity: str,
target_entity: str,
) -> None:
"""Asynchronously delete a relation between two entities.
@@ -97,6 +101,7 @@ async def adelete_by_relation(
f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}"
)
async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
"""Callback after relation deletion is complete, ensures updates are persisted"""
await asyncio.gather(
@@ -109,13 +114,14 @@ async def _delete_relation_done(relationships_vdb, chunk_entity_relation_graph)
]
)
async def aedit_entity(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
entity_name: str,
updated_data: dict[str, str],
allow_rename: bool = True
entity_name: str,
updated_data: dict[str, str],
allow_rename: bool = True,
) -> dict[str, Any]:
"""Asynchronously edit entity information.
@@ -183,9 +189,7 @@ async def aedit_entity(
relations_to_update = []
relations_to_delete = []
# Get all edges related to the original entity
edges = await chunk_entity_relation_graph.get_node_edges(
entity_name
)
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
if edges:
# Recreate edges for the new entity
for source, target in edges:
@@ -291,15 +295,25 @@ async def aedit_entity(
await entities_vdb.upsert(entity_data)
# 4. Save changes
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(f"Entity '{entity_name}' successfully updated")
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name,
include_vector_data=True,
)
except Exception as e:
logger.error(f"Error while editing entity '{entity_name}': {e}")
raise
async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity editing is complete, ensures updates are persisted"""
await asyncio.gather(
*[
@@ -312,13 +326,14 @@ async def _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relati
]
)
async def aedit_relation(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
source_entity: str,
target_entity: str,
updated_data: dict[str, Any]
source_entity: str,
target_entity: str,
updated_data: dict[str, Any],
) -> dict[str, Any]:
"""Asynchronously edit relation information.
@@ -402,7 +417,11 @@ async def aedit_relation(
f"Relation from '{source_entity}' to '{target_entity}' successfully updated"
)
return await get_relation_info(
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
chunk_entity_relation_graph,
relationships_vdb,
source_entity,
target_entity,
include_vector_data=True,
)
except Exception as e:
logger.error(
@@ -410,6 +429,7 @@ async def aedit_relation(
)
raise
async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) -> None:
"""Callback after relation editing is complete, ensures updates are persisted"""
await asyncio.gather(
@@ -422,12 +442,13 @@ async def _edit_relation_done(relationships_vdb, chunk_entity_relation_graph) ->
]
)
async def acreate_entity(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
entity_name: str,
entity_data: dict[str, Any]
entity_name: str,
entity_data: dict[str, Any],
) -> dict[str, Any]:
"""Asynchronously create a new entity.
@@ -487,21 +508,29 @@ async def acreate_entity(
await entities_vdb.upsert(entity_data_for_vdb)
# Save changes
await _edit_entity_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _edit_entity_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(f"Entity '{entity_name}' successfully created")
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, entity_name, include_vector_data=True)
return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name,
include_vector_data=True,
)
except Exception as e:
logger.error(f"Error while creating entity '{entity_name}': {e}")
raise
async def acreate_relation(
chunk_entity_relation_graph,
entities_vdb,
relationships_vdb,
source_entity: str,
target_entity: str,
relation_data: dict[str, Any]
source_entity: str,
target_entity: str,
relation_data: dict[str, Any],
) -> dict[str, Any]:
"""Asynchronously create a new relation between entities.
@@ -523,12 +552,8 @@ async def acreate_relation(
async with graph_db_lock:
try:
# Check if both entities exist
source_exists = await chunk_entity_relation_graph.has_node(
source_entity
)
target_exists = await chunk_entity_relation_graph.has_node(
target_entity
)
source_exists = await chunk_entity_relation_graph.has_node(source_entity)
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
if not source_exists:
raise ValueError(f"Source entity '{source_entity}' does not exist")
@@ -594,7 +619,11 @@ async def acreate_relation(
f"Relation from '{source_entity}' to '{target_entity}' successfully created"
)
return await get_relation_info(
chunk_entity_relation_graph, relationships_vdb, source_entity, target_entity, include_vector_data=True
chunk_entity_relation_graph,
relationships_vdb,
source_entity,
target_entity,
include_vector_data=True,
)
except Exception as e:
logger.error(
@@ -602,6 +631,7 @@ async def acreate_relation(
)
raise
async def amerge_entities(
chunk_entity_relation_graph,
entities_vdb,
@@ -657,18 +687,14 @@ async def amerge_entities(
# 1. Check if all source entities exist
source_entities_data = {}
for entity_name in source_entities:
node_exists = await chunk_entity_relation_graph.has_node(
entity_name
)
node_exists = await chunk_entity_relation_graph.has_node(entity_name)
if not node_exists:
raise ValueError(f"Source entity '{entity_name}' does not exist")
node_data = await chunk_entity_relation_graph.get_node(entity_name)
source_entities_data[entity_name] = node_data
# 2. Check if target entity exists and get its data if it does
target_exists = await chunk_entity_relation_graph.has_node(
target_entity
)
target_exists = await chunk_entity_relation_graph.has_node(target_entity)
existing_target_entity_data = {}
if target_exists:
existing_target_entity_data = (
@@ -693,9 +719,7 @@ async def amerge_entities(
all_relations = []
for entity_name in source_entities:
# Get all relationships of the source entities
edges = await chunk_entity_relation_graph.get_node_edges(
entity_name
)
edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
if edges:
for src, tgt in edges:
# Ensure src is the current entity
@@ -842,17 +866,25 @@ async def amerge_entities(
)
# 10. Save changes
await _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph)
await _merge_entities_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
)
logger.info(
f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
)
return await get_entity_info(chunk_entity_relation_graph, entities_vdb, target_entity, include_vector_data=True)
return await get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
target_entity,
include_vector_data=True,
)
except Exception as e:
logger.error(f"Error merging entities: {e}")
raise
def _merge_entity_attributes(
entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]:
@@ -902,6 +934,7 @@ def _merge_entity_attributes(
return merged_data
def _merge_relation_attributes(
relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
) -> dict[str, Any]:
@@ -925,9 +958,7 @@ def _merge_relation_attributes(
for key in all_keys:
# Get all values for this key
values = [
data.get(key)
for data in relation_data_list
if data.get(key) is not None
data.get(key) for data in relation_data_list if data.get(key) is not None
]
if not values:
@@ -961,7 +992,10 @@ def _merge_relation_attributes(
return merged_data
async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_relation_graph) -> None:
async def _merge_entities_done(
entities_vdb, relationships_vdb, chunk_entity_relation_graph
) -> None:
"""Callback after entity merging is complete, ensures updates are persisted"""
await asyncio.gather(
*[
@@ -974,11 +1008,12 @@ async def _merge_entities_done(entities_vdb, relationships_vdb, chunk_entity_rel
]
)
async def get_entity_info(
chunk_entity_relation_graph,
entities_vdb,
entity_name: str,
include_vector_data: bool = False
entity_name: str,
include_vector_data: bool = False,
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity"""
@@ -1000,19 +1035,18 @@ async def get_entity_info(
return result
async def get_relation_info(
chunk_entity_relation_graph,
relationships_vdb,
src_entity: str,
tgt_entity: str,
include_vector_data: bool = False
src_entity: str,
tgt_entity: str,
include_vector_data: bool = False,
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of a relationship"""
# Get information from the graph
edge_data = await chunk_entity_relation_graph.get_edge(
src_entity, tgt_entity
)
edge_data = await chunk_entity_relation_graph.get_edge(src_entity, tgt_entity)
source_id = edge_data.get("source_id") if edge_data else None
result: dict[str, str | None | dict[str, str]] = {