Add missing `await result.consume()` calls

This commit is contained in:
yangdx
2025-03-08 02:39:51 +08:00
parent af803f4e7a
commit c07b592e1b

View File

@@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage):
MAX_CONNECTION_POOL_SIZE = int( MAX_CONNECTION_POOL_SIZE = int(
os.environ.get( os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE", "NEO4J_MAX_CONNECTION_POOL_SIZE",
config.get("neo4j", "connection_pool_size", fallback=50), # Reduced from 800 config.get("neo4j", "connection_pool_size", fallback=50),
) )
) )
CONNECTION_TIMEOUT = float( CONNECTION_TIMEOUT = float(
os.environ.get( os.environ.get(
"NEO4J_CONNECTION_TIMEOUT", "NEO4J_CONNECTION_TIMEOUT",
config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0 config.get("neo4j", "connection_timeout", fallback=30.0),
), ),
) )
CONNECTION_ACQUISITION_TIMEOUT = float( CONNECTION_ACQUISITION_TIMEOUT = float(
os.environ.get( os.environ.get(
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT", "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0 config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
), ),
) )
MAX_TRANSACTION_RETRY_TIME = float( MAX_TRANSACTION_RETRY_TIME = float(
@@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = await self._ensure_label(node_id) entity_name_label = await self._ensure_label(node_id)
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = ( query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
) )
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
)
return single_result["node_exists"] return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = ( query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists" "RETURN COUNT(r) > 0 AS edgeExists"
@@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
)
return single_result["edgeExists"] return single_result["edgeExists"]
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
@@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage):
dict: Node properties if found dict: Node properties if found
None: If node not found None: If node not found
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
entity_name_label = await self._ensure_label(node_id) entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n" query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query) result = await session.run(query)
records = await result.fetch(2) # Get up to 2 records to check for duplicates records = await result.fetch(
2
) # Get up to 2 records to check for duplicates
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
if len(records) > 1: if len(records) > 1:
logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.") logger.warning(
f"Multiple nodes found with label '{entity_name_label}'. Using first node."
)
if records: if records:
node = records[0]["n"] node = records[0]["n"]
node_dict = dict(node) node_dict = dict(node)
@@ -248,16 +252,18 @@ class Neo4JStorage(BaseGraphStorage):
"""Get the degree (number of relationships) of a node with the given label. """Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node. If multiple nodes have the same label, returns the degree of the first node.
If no node is found, returns 0. If no node is found, returns 0.
Args: Args:
node_id: The label of the node node_id: The label of the node
Returns: Returns:
int: The number of relationships the node has, or 0 if no node found int: The number of relationships the node has, or 0 if no node found
""" """
entity_name_label = node_id.strip('"') entity_name_label = node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f""" query = f"""
MATCH (n:`{entity_name_label}`) MATCH (n:`{entity_name_label}`)
OPTIONAL MATCH (n)-[r]-() OPTIONAL MATCH (n)-[r]-()
@@ -266,14 +272,16 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
records = await result.fetch(100) records = await result.fetch(100)
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
if not records: if not records:
logger.warning(f"No node found with label '{entity_name_label}'") logger.warning(f"No node found with label '{entity_name_label}'")
return 0 return 0
if len(records) > 1: if len(records) > 1:
logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree") logger.warning(
f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
)
degree = records[0]["degree"] degree = records[0]["degree"]
logger.debug( logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
@@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
) )
return degrees return degrees
async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
"""Find all labels that have multiple nodes
Returns:
list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
MATCH (n)
WITH labels(n) as nodeLabels
UNWIND nodeLabels as label
WITH label, count(*) as node_count
WHERE node_count > 1
RETURN label, node_count
ORDER BY node_count DESC
"""
result = await session.run(query)
duplicates = []
async for record in result:
label = record["label"]
count = record["node_count"]
logger.info(f"Found {count} nodes with label: {label}")
duplicates.append((label, count))
return duplicates
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
@@ -328,64 +312,69 @@ class Neo4JStorage(BaseGraphStorage):
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f""" query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
""" """
result = await session.run(query) result = await session.run(query)
records = await result.fetch(2) # Get up to 2 records to check for duplicates try:
if len(records) > 1: records = await result.fetch(2) # Get up to 2 records to check for duplicates
logger.warning( if len(records) > 1:
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." logger.warning(
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
)
if records:
try:
result = dict(records[0]["edge_properties"])
logger.debug(f"Result: {result}")
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
) )
if records: # Return default edge properties when no edge found
try: return {
result = dict(records[0]["edge_properties"]) "weight": 0.0,
logger.debug(f"Result: {result}") "description": None,
# Ensure required keys exist with defaults "keywords": None,
required_keys = { "source_id": None,
"weight": 0.0, }
"source_id": None, finally:
"description": None, await result.consume() # Ensure result is fully consumed
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
)
# Return default edge properties when no edge found
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
except Exception as e: except Exception as e:
logger.error( logger.error(
@@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`) query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected""" RETURN n, r, connected"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = await session.run(query) results = await session.run(query)
edges = [] edges = []
try: try:
@@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
if source_label and target_label: if source_label and target_label:
edges.append((source_label, target_label)) edges.append((source_label, target_label))
finally: finally:
await results.consume() # Ensure results are consumed even if processing fails await (
results.consume()
) # Ensure results are consumed even if processing fails
return edges return edges
@@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage):
MERGE (n:`{label}`) MERGE (n:`{label}`)
SET n += $properties SET n += $properties
""" """
await tx.run(query, properties=properties) result = await tx.run(query, properties=properties)
logger.debug( logger.debug(
f"Upserted node with label '{label}' and properties: {properties}" f"Upserted node with label '{label}' and properties: {properties}"
) )
await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
@@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage):
target_exists = await self.has_node(target_label) target_exists = await self.has_node(target_label)
if not source_exists: if not source_exists:
raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist") raise ValueError(
f"Neo4j: source node with label '{source_label}' does not exist"
)
if not target_exists: if not target_exists:
raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist") raise ValueError(
f"Neo4j: target node with label '{target_label}' does not exist"
)
async def _do_upsert_edge(tx: AsyncManagedTransaction): async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f""" query = f"""
@@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage):
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try: try:
if label == "*": if label == "*":
main_query = """ main_query = """
@@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage):
visited_nodes.add(node_id) visited_nodes.add(node_id)
# Add node data with label as ID # Add node data with label as ID
result["nodes"].append({ result["nodes"].append(
"id": current_label, {"id": current_label, "labels": current_label, "properties": node}
"labels": current_label, )
"properties": node
})
# Get connected nodes that meet the degree requirement # Get connected nodes that meet the degree requirement
# Note: We don't need to check a's degree since it's the current node # Note: We don't need to check a's degree since it's the current node
@@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage):
WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
RETURN r, b RETURN r, b
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = await session.run(query, {"min_degree": min_degree}) results = await session.run(query, {"min_degree": min_degree})
async for record in results: async for record in results:
# Handle edges # Handle edges
@@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage):
b_node = record["b"] b_node = record["b"]
if b_node.labels: # Only process if target node has labels if b_node.labels: # Only process if target node has labels
target_label = list(b_node.labels)[0] target_label = list(b_node.labels)[0]
result["edges"].append({ result["edges"].append(
"id": f"{current_label}_{target_label}", {
"type": rel.type, "id": f"{current_label}_{target_label}",
"source": current_label, "type": rel.type,
"target": target_label, "source": current_label,
"properties": dict(rel) "target": target_label,
}) "properties": dict(rel),
}
)
visited_edges.add(edge_id) visited_edges.add(edge_id)
# Continue traversal # Continue traversal
await traverse(target_label, current_depth + 1) await traverse(target_label, current_depth + 1)
else: else:
logger.warning(f"Skipping edge {edge_id} due to missing labels on target node") logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
await traverse(label, 0) await traverse(label, 0)
return result return result
@@ -777,7 +781,9 @@ class Neo4JStorage(BaseGraphStorage):
Returns: Returns:
["Person", "Company", ...] # Alphabetically sorted label list ["Person", "Company", ...] # Alphabetically sorted label list
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
# Method 1: Direct metadata query (Available for Neo4j 4.3+) # Method 1: Direct metadata query (Available for Neo4j 4.3+)
# query = "CALL db.labels() YIELD label RETURN label" # query = "CALL db.labels() YIELD label RETURN label"
@@ -796,7 +802,9 @@ class Neo4JStorage(BaseGraphStorage):
async for record in result: async for record in result:
labels.append(record["label"]) labels.append(record["label"])
finally: finally:
await result.consume() # Ensure results are consumed even if processing fails await (
result.consume()
) # Ensure results are consumed even if processing fails
return labels return labels
@retry( @retry(
@@ -824,8 +832,9 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (n:`{label}`) MATCH (n:`{label}`)
DETACH DELETE n DETACH DELETE n
""" """
await tx.run(query) result = await tx.run(query)
logger.debug(f"Deleted node with label '{label}'") logger.debug(f"Deleted node with label '{label}'")
await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
@@ -882,8 +891,9 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`) MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
DELETE r DELETE r
""" """
await tx.run(query) result = await tx.run(query)
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'") logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session: