Add missing await consume

This commit is contained in:
yangdx
2025-03-08 02:39:51 +08:00
parent af803f4e7a
commit c07b592e1b

View File

@@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage):
MAX_CONNECTION_POOL_SIZE = int(
os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE",
config.get("neo4j", "connection_pool_size", fallback=50), # Reduced from 800
config.get("neo4j", "connection_pool_size", fallback=50),
)
)
CONNECTION_TIMEOUT = float(
os.environ.get(
"NEO4J_CONNECTION_TIMEOUT",
config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0
config.get("neo4j", "connection_timeout", fallback=30.0),
),
)
CONNECTION_ACQUISITION_TIMEOUT = float(
os.environ.get(
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
),
)
MAX_TRANSACTION_RETRY_TIME = float(
@@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool:
entity_name_label = await self._ensure_label(node_id)
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
result = await session.run(query)
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
@@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query)
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
)
return single_result["edgeExists"]
async def get_node(self, node_id: str) -> dict[str, str] | None:
@@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage):
dict: Node properties if found
None: If node not found
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
records = await result.fetch(2) # Get up to 2 records to check for duplicates
records = await result.fetch(
2
) # Get up to 2 records to check for duplicates
await result.consume() # Ensure result is fully consumed
if len(records) > 1:
logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.")
logger.warning(
f"Multiple nodes found with label '{entity_name_label}'. Using first node."
)
if records:
node = records[0]["n"]
node_dict = dict(node)
@@ -248,16 +252,18 @@ class Neo4JStorage(BaseGraphStorage):
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
If no node is found, returns 0.
Args:
node_id: The label of the node
Returns:
int: The number of relationships the node has, or 0 if no node found
"""
entity_name_label = node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
MATCH (n:`{entity_name_label}`)
OPTIONAL MATCH (n)-[r]-()
@@ -266,14 +272,16 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query)
records = await result.fetch(100)
await result.consume() # Ensure result is fully consumed
if not records:
logger.warning(f"No node found with label '{entity_name_label}'")
return 0
if len(records) > 1:
logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree")
logger.warning(
f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
)
degree = records[0]["degree"]
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
@@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
)
return degrees
async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
"""Find all labels that have multiple nodes
Returns:
list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
MATCH (n)
WITH labels(n) as nodeLabels
UNWIND nodeLabels as label
WITH label, count(*) as node_count
WHERE node_count > 1
RETURN label, node_count
ORDER BY node_count DESC
"""
result = await session.run(query)
duplicates = []
async for record in result:
label = record["label"]
count = record["node_count"]
logger.info(f"Found {count} nodes with label: {label}")
duplicates.append((label, count))
return duplicates
async def get_edge(
self, source_node_id: str, target_node_id: str
@@ -328,64 +312,69 @@ class Neo4JStorage(BaseGraphStorage):
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
"""
result = await session.run(query)
records = await result.fetch(2) # Get up to 2 records to check for duplicates
if len(records) > 1:
logger.warning(
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
try:
records = await result.fetch(2) # Get up to 2 records to check for duplicates
if len(records) > 1:
logger.warning(
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
)
if records:
try:
result = dict(records[0]["edge_properties"])
logger.debug(f"Result: {result}")
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
)
if records:
try:
result = dict(records[0]["edge_properties"])
logger.debug(f"Result: {result}")
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
)
# Return default edge properties when no edge found
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
# Return default edge properties when no edge found
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(
@@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = await session.run(query)
edges = []
try:
@@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
if source_label and target_label:
edges.append((source_label, target_label))
finally:
await results.consume() # Ensure results are consumed even if processing fails
await (
results.consume()
) # Ensure results are consumed even if processing fails
return edges
@@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage):
MERGE (n:`{label}`)
SET n += $properties
"""
await tx.run(query, properties=properties)
result = await tx.run(query, properties=properties)
logger.debug(
f"Upserted node with label '{label}' and properties: {properties}"
)
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
@@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage):
target_exists = await self.has_node(target_label)
if not source_exists:
raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist")
raise ValueError(
f"Neo4j: source node with label '{source_label}' does not exist"
)
if not target_exists:
raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist")
raise ValueError(
f"Neo4j: target node with label '{target_label}' does not exist"
)
async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f"""
@@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage):
seen_nodes = set()
seen_edges = set()
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
if label == "*":
main_query = """
@@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage):
visited_nodes.add(node_id)
# Add node data with label as ID
result["nodes"].append({
"id": current_label,
"labels": current_label,
"properties": node
})
result["nodes"].append(
{"id": current_label, "labels": current_label, "properties": node}
)
# Get connected nodes that meet the degree requirement
# Note: We don't need to check a's degree since it's the current node
@@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage):
WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
RETURN r, b
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = await session.run(query, {"min_degree": min_degree})
async for record in results:
# Handle edges
@@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage):
b_node = record["b"]
if b_node.labels: # Only process if target node has labels
target_label = list(b_node.labels)[0]
result["edges"].append({
"id": f"{current_label}_{target_label}",
"type": rel.type,
"source": current_label,
"target": target_label,
"properties": dict(rel)
})
result["edges"].append(
{
"id": f"{current_label}_{target_label}",
"type": rel.type,
"source": current_label,
"target": target_label,
"properties": dict(rel),
}
)
visited_edges.add(edge_id)
# Continue traversal
await traverse(target_label, current_depth + 1)
else:
logger.warning(f"Skipping edge {edge_id} due to missing labels on target node")
logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
await traverse(label, 0)
return result
@@ -777,7 +781,9 @@ class Neo4JStorage(BaseGraphStorage):
Returns:
["Person", "Company", ...] # Alphabetically sorted label list
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
# query = "CALL db.labels() YIELD label RETURN label"
@@ -796,7 +802,9 @@ class Neo4JStorage(BaseGraphStorage):
async for record in result:
labels.append(record["label"])
finally:
await result.consume() # Ensure results are consumed even if processing fails
await (
result.consume()
) # Ensure results are consumed even if processing fails
return labels
@retry(
@@ -824,8 +832,9 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (n:`{label}`)
DETACH DELETE n
"""
await tx.run(query)
result = await tx.run(query)
logger.debug(f"Deleted node with label '{label}'")
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
@@ -882,8 +891,9 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
DELETE r
"""
await tx.run(query)
result = await tx.run(query)
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session: