Add missing await consume
This commit is contained in:
@@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
MAX_CONNECTION_POOL_SIZE = int(
|
||||
os.environ.get(
|
||||
"NEO4J_MAX_CONNECTION_POOL_SIZE",
|
||||
config.get("neo4j", "connection_pool_size", fallback=50), # Reduced from 800
|
||||
config.get("neo4j", "connection_pool_size", fallback=50),
|
||||
)
|
||||
)
|
||||
CONNECTION_TIMEOUT = float(
|
||||
os.environ.get(
|
||||
"NEO4J_CONNECTION_TIMEOUT",
|
||||
config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0
|
||||
config.get("neo4j", "connection_timeout", fallback=30.0),
|
||||
),
|
||||
)
|
||||
CONNECTION_ACQUISITION_TIMEOUT = float(
|
||||
os.environ.get(
|
||||
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
|
||||
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0
|
||||
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
|
||||
),
|
||||
)
|
||||
MAX_TRANSACTION_RETRY_TIME = float(
|
||||
@@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = await self._ensure_label(node_id)
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = (
|
||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||
)
|
||||
result = await session.run(query)
|
||||
single_result = await result.single()
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
|
||||
)
|
||||
return single_result["node_exists"]
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = (
|
||||
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
||||
"RETURN COUNT(r) > 0 AS edgeExists"
|
||||
@@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
result = await session.run(query)
|
||||
single_result = await result.single()
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
|
||||
)
|
||||
return single_result["edgeExists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
@@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
dict: Node properties if found
|
||||
None: If node not found
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
entity_name_label = await self._ensure_label(node_id)
|
||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||
result = await session.run(query)
|
||||
records = await result.fetch(2) # Get up to 2 records to check for duplicates
|
||||
records = await result.fetch(
|
||||
2
|
||||
) # Get up to 2 records to check for duplicates
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
if len(records) > 1:
|
||||
logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.")
|
||||
logger.warning(
|
||||
f"Multiple nodes found with label '{entity_name_label}'. Using first node."
|
||||
)
|
||||
if records:
|
||||
node = records[0]["n"]
|
||||
node_dict = dict(node)
|
||||
@@ -257,7 +261,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"""
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = f"""
|
||||
MATCH (n:`{entity_name_label}`)
|
||||
OPTIONAL MATCH (n)-[r]-()
|
||||
@@ -272,7 +278,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return 0
|
||||
|
||||
if len(records) > 1:
|
||||
logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree")
|
||||
logger.warning(
|
||||
f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
|
||||
)
|
||||
|
||||
degree = records[0]["degree"]
|
||||
logger.debug(
|
||||
@@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
return degrees
|
||||
|
||||
async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
|
||||
"""Find all labels that have multiple nodes
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WITH labels(n) as nodeLabels
|
||||
UNWIND nodeLabels as label
|
||||
WITH label, count(*) as node_count
|
||||
WHERE node_count > 1
|
||||
RETURN label, node_count
|
||||
ORDER BY node_count DESC
|
||||
"""
|
||||
result = await session.run(query)
|
||||
duplicates = []
|
||||
async for record in result:
|
||||
label = record["label"]
|
||||
count = record["node_count"]
|
||||
logger.info(f"Found {count} nodes with label: {label}")
|
||||
duplicates.append((label, count))
|
||||
return duplicates
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
@@ -328,13 +312,16 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = f"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
"""
|
||||
|
||||
result = await session.run(query)
|
||||
try:
|
||||
records = await result.fetch(2) # Get up to 2 records to check for duplicates
|
||||
if len(records) > 1:
|
||||
logger.warning(
|
||||
@@ -386,6 +373,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"keywords": None,
|
||||
"source_id": None,
|
||||
}
|
||||
finally:
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
query = f"""MATCH (n:`{node_label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
results = await session.run(query)
|
||||
edges = []
|
||||
try:
|
||||
@@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
if source_label and target_label:
|
||||
edges.append((source_label, target_label))
|
||||
finally:
|
||||
await results.consume() # Ensure results are consumed even if processing fails
|
||||
await (
|
||||
results.consume()
|
||||
) # Ensure results are consumed even if processing fails
|
||||
|
||||
return edges
|
||||
|
||||
@@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
MERGE (n:`{label}`)
|
||||
SET n += $properties
|
||||
"""
|
||||
await tx.run(query, properties=properties)
|
||||
result = await tx.run(query, properties=properties)
|
||||
logger.debug(
|
||||
f"Upserted node with label '{label}' and properties: {properties}"
|
||||
)
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
@@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
target_exists = await self.has_node(target_label)
|
||||
|
||||
if not source_exists:
|
||||
raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist")
|
||||
raise ValueError(
|
||||
f"Neo4j: source node with label '{source_label}' does not exist"
|
||||
)
|
||||
if not target_exists:
|
||||
raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist")
|
||||
raise ValueError(
|
||||
f"Neo4j: target node with label '{target_label}' does not exist"
|
||||
)
|
||||
|
||||
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
||||
query = f"""
|
||||
@@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
try:
|
||||
if label == "*":
|
||||
main_query = """
|
||||
@@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
visited_nodes.add(node_id)
|
||||
|
||||
# Add node data with label as ID
|
||||
result["nodes"].append({
|
||||
"id": current_label,
|
||||
"labels": current_label,
|
||||
"properties": node
|
||||
})
|
||||
result["nodes"].append(
|
||||
{"id": current_label, "labels": current_label, "properties": node}
|
||||
)
|
||||
|
||||
# Get connected nodes that meet the degree requirement
|
||||
# Note: We don't need to check a's degree since it's the current node
|
||||
@@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
|
||||
RETURN r, b
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
results = await session.run(query, {"min_degree": min_degree})
|
||||
async for record in results:
|
||||
# Handle edges
|
||||
@@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
b_node = record["b"]
|
||||
if b_node.labels: # Only process if target node has labels
|
||||
target_label = list(b_node.labels)[0]
|
||||
result["edges"].append({
|
||||
result["edges"].append(
|
||||
{
|
||||
"id": f"{current_label}_{target_label}",
|
||||
"type": rel.type,
|
||||
"source": current_label,
|
||||
"target": target_label,
|
||||
"properties": dict(rel)
|
||||
})
|
||||
"properties": dict(rel),
|
||||
}
|
||||
)
|
||||
visited_edges.add(edge_id)
|
||||
|
||||
# Continue traversal
|
||||
await traverse(target_label, current_depth + 1)
|
||||
else:
|
||||
logger.warning(f"Skipping edge {edge_id} due to missing labels on target node")
|
||||
logger.warning(
|
||||
f"Skipping edge {edge_id} due to missing labels on target node"
|
||||
)
|
||||
|
||||
await traverse(label, 0)
|
||||
return result
|
||||
@@ -777,7 +781,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
["Person", "Company", ...] # Alphabetically sorted label list
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
|
||||
# query = "CALL db.labels() YIELD label RETURN label"
|
||||
|
||||
@@ -796,7 +802,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async for record in result:
|
||||
labels.append(record["label"])
|
||||
finally:
|
||||
await result.consume() # Ensure results are consumed even if processing fails
|
||||
await (
|
||||
result.consume()
|
||||
) # Ensure results are consumed even if processing fails
|
||||
return labels
|
||||
|
||||
@retry(
|
||||
@@ -824,8 +832,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
MATCH (n:`{label}`)
|
||||
DETACH DELETE n
|
||||
"""
|
||||
await tx.run(query)
|
||||
result = await tx.run(query)
|
||||
logger.debug(f"Deleted node with label '{label}'")
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
@@ -882,8 +891,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
|
||||
DELETE r
|
||||
"""
|
||||
await tx.run(query)
|
||||
result = await tx.run(query)
|
||||
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
|
Reference in New Issue
Block a user