Add duplicate edge upsert checking and logging
This commit is contained in:
@@ -412,9 +412,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
result = await session.run(query)
|
||||
try:
|
||||
records = await result.fetch(
|
||||
2
|
||||
) # Get up to 2 records to check for duplicates
|
||||
records = await result.fetch(2)
|
||||
|
||||
if len(records) > 1:
|
||||
logger.warning(
|
||||
@@ -552,7 +550,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
label = self._ensure_label(node_id)
|
||||
properties = node_data
|
||||
|
||||
async def _do_upsert(tx: AsyncManagedTransaction):
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
query = f"""
|
||||
MERGE (n:`{label}`)
|
||||
SET n += $properties
|
||||
@@ -563,9 +563,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert)
|
||||
await session.execute_write(execute_upsert)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during upsert: {str(e)}")
|
||||
raise
|
||||
@@ -614,27 +612,39 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
f"Neo4j: target node with label '{target_label}' does not exist"
|
||||
)
|
||||
|
||||
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
query = f"""
|
||||
MATCH (source:`{source_label}`)
|
||||
WITH source
|
||||
MATCH (target:`{target_label}`)
|
||||
MERGE (source)-[r:DIRECTED]-(target)
|
||||
SET r += $properties
|
||||
RETURN r
|
||||
RETURN r, source, target
|
||||
"""
|
||||
result = await tx.run(query, properties=edge_properties)
|
||||
try:
|
||||
record = await result.single()
|
||||
records = await result.fetch(100)
|
||||
if len(records) > 1:
|
||||
source_nodes = [dict(r['source']) for r in records]
|
||||
target_nodes = [dict(r['target']) for r in records]
|
||||
logger.warning(
|
||||
f"Multiple edges created: found {len(records)} results for edge between "
|
||||
f"source label '{source_label}' and target label '{target_label}'. "
|
||||
f"Source nodes: {source_nodes}, "
|
||||
f"Target nodes: {target_nodes}. "
|
||||
"Using first edge only."
|
||||
)
|
||||
if records:
|
||||
logger.debug(
|
||||
f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
|
||||
f"Upserted edge from '{source_label}' to '{target_label}' "
|
||||
f"with properties: {edge_properties}"
|
||||
)
|
||||
finally:
|
||||
await result.consume() # Ensure result is consumed
|
||||
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert_edge)
|
||||
await session.execute_write(execute_upsert)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge upsert: {str(e)}")
|
||||
raise
|
||||
|
Reference in New Issue
Block a user