fix: duplicate nodes for same entity(label) problem in Neo4j

- Add entity_id field as key in Neo4j nodes
- Use  entity_id for nodes retrival and upsert
This commit is contained in:
yangdx
2025-03-09 00:24:55 +08:00
parent 73452e63fa
commit 18c0770409
2 changed files with 74 additions and 34 deletions

View File

@@ -280,12 +280,10 @@ class Neo4JStorage(BaseGraphStorage):
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try: try:
query = f"MATCH (n:`{entity_name_label}`) RETURN n" query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
result = await session.run(query) result = await session.run(query, entity_id=entity_name_label)
try: try:
records = await result.fetch( records = await result.fetch(2) # Get 2 records for duplication check
2
) # Get up to 2 records to check for duplicates
if len(records) > 1: if len(records) > 1:
logger.warning( logger.warning(
@@ -549,12 +547,14 @@ class Neo4JStorage(BaseGraphStorage):
""" """
label = self._ensure_label(node_id) label = self._ensure_label(node_id)
properties = node_data properties = node_data
if "entity_id" not in properties:
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
query = f""" query = f"""
MERGE (n:`{label}`) MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
SET n += $properties SET n += $properties
""" """
result = await tx.run(query, properties=properties) result = await tx.run(query, properties=properties)
@@ -568,6 +568,56 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error during upsert: {str(e)}") logger.error(f"Error during upsert: {str(e)}")
raise raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
async def _get_unique_node_entity_id(self, node_label: str) -> str:
"""
Get the entity_id of a node with the given label, ensuring the node is unique.
Args:
node_label (str): Label of the node to check
Returns:
str: The entity_id of the unique node
Raises:
ValueError: If no node with the given label exists or if multiple nodes have the same label
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
MATCH (n:`{node_label}`)
RETURN n, count(n) as node_count
"""
result = await session.run(query)
try:
records = await result.fetch(2) # We only need to know if there are 0, 1, or >1 nodes
if not records or records[0]["node_count"] == 0:
raise ValueError(f"Neo4j: node with label '{node_label}' does not exist")
if records[0]["node_count"] > 1:
raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node")
node = records[0]["n"]
if "entity_id" not in node:
raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property")
return node["entity_id"]
finally:
await result.consume() # Ensure result is fully consumed
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -585,7 +635,8 @@ class Neo4JStorage(BaseGraphStorage):
) -> None: ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
Checks if both source and target nodes exist before creating the edge. Ensures both source and target nodes exist and are unique before creating the edge.
Uses entity_id property to uniquely identify nodes.
Args: Args:
source_node_id (str): Label of the source node (used as identifier) source_node_id (str): Label of the source node (used as identifier)
@@ -593,52 +644,39 @@ class Neo4JStorage(BaseGraphStorage):
edge_data (dict): Dictionary of properties to set on the edge edge_data (dict): Dictionary of properties to set on the edge
Raises: Raises:
ValueError: If either source or target node does not exist ValueError: If either source or target node does not exist or is not unique
""" """
source_label = self._ensure_label(source_node_id) source_label = self._ensure_label(source_node_id)
target_label = self._ensure_label(target_node_id) target_label = self._ensure_label(target_node_id)
edge_properties = edge_data edge_properties = edge_data
# Check if both nodes exist # Get entity_ids for source and target nodes, ensuring they are unique
source_exists = await self.has_node(source_label) source_entity_id = await self._get_unique_node_entity_id(source_label)
target_exists = await self.has_node(target_label) target_entity_id = await self._get_unique_node_entity_id(target_label)
if not source_exists:
raise ValueError(
f"Neo4j: source node with label '{source_label}' does not exist"
)
if not target_exists:
raise ValueError(
f"Neo4j: target node with label '{target_label}' does not exist"
)
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
query = f""" query = f"""
MATCH (source:`{source_label}`) MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
WITH source WITH source
MATCH (target:`{target_label}`) MATCH (target:`{target_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target) MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties SET r += $properties
RETURN r, source, target RETURN r, source, target
""" """
result = await tx.run(query, properties=edge_properties) result = await tx.run(
query,
source_entity_id=source_entity_id,
target_entity_id=target_entity_id,
properties=edge_properties
)
try: try:
records = await result.fetch(100) records = await result.fetch(100)
if len(records) > 1:
source_nodes = [dict(r['source']) for r in records]
target_nodes = [dict(r['target']) for r in records]
logger.warning(
f"Multiple edges created: found {len(records)} results for edge between "
f"source label '{source_label}' and target label '{target_label}'. "
f"Source nodes: {source_nodes}, "
f"Target nodes: {target_nodes}. "
"Using first edge only."
)
if records: if records:
logger.debug( logger.debug(
f"Upserted edge from '{source_label}' to '{target_label}' " f"Upserted edge from '{source_label}' (entity_id: {source_entity_id}) "
f"to '{target_label}' (entity_id: {target_entity_id}) "
f"with properties: {edge_properties}" f"with properties: {edge_properties}"
) )
finally: finally:

View File

@@ -220,6 +220,7 @@ async def _merge_nodes_then_upsert(
entity_name, description, global_config entity_name, description, global_config
) )
node_data = dict( node_data = dict(
entity_id=entity_name,
entity_type=entity_type, entity_type=entity_type,
description=description, description=description,
source_id=source_id, source_id=source_id,
@@ -301,6 +302,7 @@ async def _merge_edges_then_upsert(
await knowledge_graph_inst.upsert_node( await knowledge_graph_inst.upsert_node(
need_insert_id, need_insert_id,
node_data={ node_data={
"entity_id": need_insert_id,
"source_id": source_id, "source_id": source_id,
"description": description, "description": description,
"entity_type": "UNKNOWN", "entity_type": "UNKNOWN",