Add missing await consume

This commit is contained in:
yangdx
2025-03-08 02:39:51 +08:00
parent af803f4e7a
commit c07b592e1b

View File

@@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage):
MAX_CONNECTION_POOL_SIZE = int( MAX_CONNECTION_POOL_SIZE = int(
os.environ.get( os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE", "NEO4J_MAX_CONNECTION_POOL_SIZE",
config.get("neo4j", "connection_pool_size", fallback=50), # Reduced from 800 config.get("neo4j", "connection_pool_size", fallback=50),
) )
) )
CONNECTION_TIMEOUT = float( CONNECTION_TIMEOUT = float(
os.environ.get( os.environ.get(
"NEO4J_CONNECTION_TIMEOUT", "NEO4J_CONNECTION_TIMEOUT",
config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0 config.get("neo4j", "connection_timeout", fallback=30.0),
), ),
) )
CONNECTION_ACQUISITION_TIMEOUT = float( CONNECTION_ACQUISITION_TIMEOUT = float(
os.environ.get( os.environ.get(
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT", "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0 config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
), ),
) )
MAX_TRANSACTION_RETRY_TIME = float( MAX_TRANSACTION_RETRY_TIME = float(
@@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = await self._ensure_label(node_id) entity_name_label = await self._ensure_label(node_id)
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = ( query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
) )
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
)
return single_result["node_exists"] return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = ( query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists" "RETURN COUNT(r) > 0 AS edgeExists"
@@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
)
return single_result["edgeExists"] return single_result["edgeExists"]
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
@@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage):
dict: Node properties if found dict: Node properties if found
None: If node not found None: If node not found
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
entity_name_label = await self._ensure_label(node_id) entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n" query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query) result = await session.run(query)
records = await result.fetch(2) # Get up to 2 records to check for duplicates records = await result.fetch(
2
) # Get up to 2 records to check for duplicates
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
if len(records) > 1: if len(records) > 1:
logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.") logger.warning(
f"Multiple nodes found with label '{entity_name_label}'. Using first node."
)
if records: if records:
node = records[0]["n"] node = records[0]["n"]
node_dict = dict(node) node_dict = dict(node)
@@ -257,7 +261,9 @@ class Neo4JStorage(BaseGraphStorage):
""" """
entity_name_label = node_id.strip('"') entity_name_label = node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f""" query = f"""
MATCH (n:`{entity_name_label}`) MATCH (n:`{entity_name_label}`)
OPTIONAL MATCH (n)-[r]-() OPTIONAL MATCH (n)-[r]-()
@@ -272,7 +278,9 @@ class Neo4JStorage(BaseGraphStorage):
return 0 return 0
if len(records) > 1: if len(records) > 1:
logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree") logger.warning(
f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
)
degree = records[0]["degree"] degree = records[0]["degree"]
logger.debug( logger.debug(
@@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
) )
return degrees return degrees
async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
"""Find all labels that have multiple nodes
Returns:
list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
MATCH (n)
WITH labels(n) as nodeLabels
UNWIND nodeLabels as label
WITH label, count(*) as node_count
WHERE node_count > 1
RETURN label, node_count
ORDER BY node_count DESC
"""
result = await session.run(query)
duplicates = []
async for record in result:
label = record["label"]
count = record["node_count"]
logger.info(f"Found {count} nodes with label: {label}")
duplicates.append((label, count))
return duplicates
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
@@ -328,13 +312,16 @@ class Neo4JStorage(BaseGraphStorage):
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f""" query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
""" """
result = await session.run(query) result = await session.run(query)
try:
records = await result.fetch(2) # Get up to 2 records to check for duplicates records = await result.fetch(2) # Get up to 2 records to check for duplicates
if len(records) > 1: if len(records) > 1:
logger.warning( logger.warning(
@@ -386,6 +373,8 @@ class Neo4JStorage(BaseGraphStorage):
"keywords": None, "keywords": None,
"source_id": None, "source_id": None,
} }
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e: except Exception as e:
logger.error( logger.error(
@@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`) query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected""" RETURN n, r, connected"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = await session.run(query) results = await session.run(query)
edges = [] edges = []
try: try:
@@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
if source_label and target_label: if source_label and target_label:
edges.append((source_label, target_label)) edges.append((source_label, target_label))
finally: finally:
await results.consume() # Ensure results are consumed even if processing fails await (
results.consume()
) # Ensure results are consumed even if processing fails
return edges return edges
@@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage):
MERGE (n:`{label}`) MERGE (n:`{label}`)
SET n += $properties SET n += $properties
""" """
await tx.run(query, properties=properties) result = await tx.run(query, properties=properties)
logger.debug( logger.debug(
f"Upserted node with label '{label}' and properties: {properties}" f"Upserted node with label '{label}' and properties: {properties}"
) )
await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
@@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage):
target_exists = await self.has_node(target_label) target_exists = await self.has_node(target_label)
if not source_exists: if not source_exists:
raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist") raise ValueError(
f"Neo4j: source node with label '{source_label}' does not exist"
)
if not target_exists: if not target_exists:
raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist") raise ValueError(
f"Neo4j: target node with label '{target_label}' does not exist"
)
async def _do_upsert_edge(tx: AsyncManagedTransaction): async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f""" query = f"""
@@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage):
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try: try:
if label == "*": if label == "*":
main_query = """ main_query = """
@@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage):
visited_nodes.add(node_id) visited_nodes.add(node_id)
# Add node data with label as ID # Add node data with label as ID
result["nodes"].append({ result["nodes"].append(
"id": current_label, {"id": current_label, "labels": current_label, "properties": node}
"labels": current_label, )
"properties": node
})
# Get connected nodes that meet the degree requirement # Get connected nodes that meet the degree requirement
# Note: We don't need to check a's degree since it's the current node # Note: We don't need to check a's degree since it's the current node
@@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage):
WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
RETURN r, b RETURN r, b
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = await session.run(query, {"min_degree": min_degree}) results = await session.run(query, {"min_degree": min_degree})
async for record in results: async for record in results:
# Handle edges # Handle edges
@@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage):
b_node = record["b"] b_node = record["b"]
if b_node.labels: # Only process if target node has labels if b_node.labels: # Only process if target node has labels
target_label = list(b_node.labels)[0] target_label = list(b_node.labels)[0]
result["edges"].append({ result["edges"].append(
{
"id": f"{current_label}_{target_label}", "id": f"{current_label}_{target_label}",
"type": rel.type, "type": rel.type,
"source": current_label, "source": current_label,
"target": target_label, "target": target_label,
"properties": dict(rel) "properties": dict(rel),
}) }
)
visited_edges.add(edge_id) visited_edges.add(edge_id)
# Continue traversal # Continue traversal
await traverse(target_label, current_depth + 1) await traverse(target_label, current_depth + 1)
else: else:
logger.warning(f"Skipping edge {edge_id} due to missing labels on target node") logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
await traverse(label, 0) await traverse(label, 0)
return result return result
@@ -777,7 +781,9 @@ class Neo4JStorage(BaseGraphStorage):
Returns: Returns:
["Person", "Company", ...] # Alphabetically sorted label list ["Person", "Company", ...] # Alphabetically sorted label list
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
# Method 1: Direct metadata query (Available for Neo4j 4.3+) # Method 1: Direct metadata query (Available for Neo4j 4.3+)
# query = "CALL db.labels() YIELD label RETURN label" # query = "CALL db.labels() YIELD label RETURN label"
@@ -796,7 +802,9 @@ class Neo4JStorage(BaseGraphStorage):
async for record in result: async for record in result:
labels.append(record["label"]) labels.append(record["label"])
finally: finally:
await result.consume() # Ensure results are consumed even if processing fails await (
result.consume()
) # Ensure results are consumed even if processing fails
return labels return labels
@retry( @retry(
@@ -824,8 +832,9 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (n:`{label}`) MATCH (n:`{label}`)
DETACH DELETE n DETACH DELETE n
""" """
await tx.run(query) result = await tx.run(query)
logger.debug(f"Deleted node with label '{label}'") logger.debug(f"Deleted node with label '{label}'")
await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
@@ -882,8 +891,9 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`) MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
DELETE r DELETE r
""" """
await tx.run(query) result = await tx.run(query)
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'") logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session: