Refactor Neo4J graph query with min_degree and inclusive match support

yangdx
2025-03-08 01:20:36 +08:00
parent 0ee2e7fd48
commit af803f4e7a


@@ -41,6 +41,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
 # Set neo4j logger level to ERROR to suppress warning logs
 logging.getLogger("neo4j").setLevel(logging.ERROR)
+

 @final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
@@ -63,19 +64,25 @@ class Neo4JStorage(BaseGraphStorage):
         MAX_CONNECTION_POOL_SIZE = int(
             os.environ.get(
                 "NEO4J_MAX_CONNECTION_POOL_SIZE",
-                config.get("neo4j", "connection_pool_size", fallback=800),
+                config.get("neo4j", "connection_pool_size", fallback=50),  # Reduced from 800
             )
         )
         CONNECTION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_TIMEOUT",
-                config.get("neo4j", "connection_timeout", fallback=60.0),
+                config.get("neo4j", "connection_timeout", fallback=30.0),  # Reduced from 60.0
             ),
         )
         CONNECTION_ACQUISITION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
-                config.get("neo4j", "connection_acquisition_timeout", fallback=60.0),
+                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),  # Reduced from 60.0
+            ),
+        )
+        MAX_TRANSACTION_RETRY_TIME = float(
+            os.environ.get(
+                "NEO4J_MAX_TRANSACTION_RETRY_TIME",
+                config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
             ),
         )
         DATABASE = os.environ.get(
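
For reference, the new pool/timeout/retry knobs can be overridden before the storage is constructed. A hedged sketch using the environment variable names from this hunk (the values below are placeholders, not recommendations; environment variables take precedence over the config-file fallbacks shown above):

    import os

    # Placeholders only: pick values appropriate for your deployment.
    os.environ["NEO4J_MAX_CONNECTION_POOL_SIZE"] = "100"
    os.environ["NEO4J_CONNECTION_TIMEOUT"] = "30.0"
    os.environ["NEO4J_CONNECTION_ACQUISITION_TIMEOUT"] = "30.0"
    os.environ["NEO4J_MAX_TRANSACTION_RETRY_TIME"] = "30.0"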
@@ -88,6 +95,7 @@ class Neo4JStorage(BaseGraphStorage):
             max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
             connection_timeout=CONNECTION_TIMEOUT,
             connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
+            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
         )

         # Try to connect to the database
@@ -174,16 +182,19 @@ class Neo4JStorage(BaseGraphStorage):
             label: The label to validate
         """
         clean_label = label.strip('"')
+        if not clean_label:
+            raise ValueError("Neo4j: Label cannot be empty")
         return clean_label

     async def has_node(self, node_id: str) -> bool:
         entity_name_label = await self._ensure_label(node_id)
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = (
                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
             )
             result = await session.run(query)
             single_result = await result.single()
+            await result.consume()  # Ensure result is fully consumed
             logger.debug(
                 f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
             )
@@ -193,13 +204,14 @@ class Neo4JStorage(BaseGraphStorage):
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')

-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = (
                 f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
                 "RETURN COUNT(r) > 0 AS edgeExists"
             )
             result = await session.run(query)
             single_result = await result.single()
+            await result.consume()  # Ensure result is fully consumed
             logger.debug(
                 f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
             )
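
The read paths in this commit all follow the same pattern: open the session with default_access_mode="READ" so the work can be routed to a read replica, and call consume() so the driver returns the connection to the pool promptly. A minimal standalone sketch of that pattern with the official neo4j async driver (URI and credentials are placeholders):

    from neo4j import AsyncGraphDatabase

    async def count_nodes(label: str) -> int:
        # Placeholder connection details.
        driver = AsyncGraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
        try:
            async with driver.session(database="neo4j", default_access_mode="READ") as session:
                result = await session.run(f"MATCH (n:`{label}`) RETURN count(n) AS c")
                record = await result.single()
                await result.consume()  # release the connection back to the pool
                return record["c"] if record else 0
        finally:
            await driver.close()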
@@ -215,13 +227,16 @@ class Neo4JStorage(BaseGraphStorage):
             dict: Node properties if found
             None: If node not found
         """
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             entity_name_label = await self._ensure_label(node_id)
             query = f"MATCH (n:`{entity_name_label}`) RETURN n"
             result = await session.run(query)
-            record = await result.single()
-            if record:
-                node = record["n"]
+            records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+            await result.consume()  # Ensure result is fully consumed
+            if len(records) > 1:
+                logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.")
+            if records:
+                node = records[0]["n"]
                 node_dict = dict(node)
                 logger.debug(
                     f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
@@ -230,23 +245,40 @@ class Neo4JStorage(BaseGraphStorage):
         return None

     async def node_degree(self, node_id: str) -> int:
+        """Get the degree (number of relationships) of a node with the given label.
+        If multiple nodes have the same label, returns the degree of the first node.
+        If no node is found, returns 0.
+
+        Args:
+            node_id: The label of the node
+
+        Returns:
+            int: The number of relationships the node has, or 0 if no node found
+        """
         entity_name_label = node_id.strip('"')

-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = f"""
                 MATCH (n:`{entity_name_label}`)
-                RETURN COUNT{{ (n)--() }} AS totalEdgeCount
+                OPTIONAL MATCH (n)-[r]-()
+                RETURN n, COUNT(r) AS degree
             """
             result = await session.run(query)
-            record = await result.single()
-            if record:
-                edge_count = record["totalEdgeCount"]
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
-                )
-                return edge_count
-            else:
-                return None
+            records = await result.fetch(100)
+            await result.consume()  # Ensure result is fully consumed
+
+            if not records:
+                logger.warning(f"No node found with label '{entity_name_label}'")
+                return 0
+
+            if len(records) > 1:
+                logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree")
+
+            degree = records[0]["degree"]
+            logger.debug(
+                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
+            )
+            return degree

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         entity_name_label_source = src_id.strip('"')
@@ -264,6 +296,31 @@ class Neo4JStorage(BaseGraphStorage):
             )
             return degrees

+    async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
+        """Find all labels that have multiple nodes
+
+        Returns:
+            list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
+        """
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+            query = """
+                MATCH (n)
+                WITH labels(n) as nodeLabels
+                UNWIND nodeLabels as label
+                WITH label, count(*) as node_count
+                WHERE node_count > 1
+                RETURN label, node_count
+                ORDER BY node_count DESC
+            """
+            result = await session.run(query)
+            duplicates = []
+            async for record in result:
+                label = record["label"]
+                count = record["node_count"]
+                logger.info(f"Found {count} nodes with label: {label}")
+                duplicates.append((label, count))
+            return duplicates
+
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
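
A hedged usage sketch for the new duplicate check; `storage` stands in for an already-initialized Neo4JStorage instance, whose construction is outside this diff:

    async def report_duplicate_labels(storage) -> None:
        # Each tuple is (label, node_count), sorted by count in the query above.
        duplicates = await storage.check_duplicate_nodes()
        if not duplicates:
            print("No duplicate labels found")
            return
        for label, count in duplicates:
            print(f"label {label!r} is shared by {count} nodes")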
@@ -271,18 +328,21 @@ class Neo4JStorage(BaseGraphStorage):
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')

-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = f"""
-            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
+            MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
             RETURN properties(r) as edge_properties
-            LIMIT 1
             """
             result = await session.run(query)
-            record = await result.single()
-            if record:
+            records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+            if len(records) > 1:
+                logger.warning(
+                    f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
+                )
+            if records:
                 try:
-                    result = dict(record["edge_properties"])
+                    result = dict(records[0]["edge_properties"])
                     logger.debug(f"Result: {result}")
                     # Ensure required keys exist with defaults
                     required_keys = {
@@ -349,24 +409,27 @@ class Neo4JStorage(BaseGraphStorage):
         query = f"""MATCH (n:`{node_label}`)
                     OPTIONAL MATCH (n)-[r]-(connected)
                     RETURN n, r, connected"""
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             results = await session.run(query)
             edges = []
+            try:
                 async for record in results:
                     source_node = record["n"]
                     connected_node = record["connected"]

                     source_label = (
                         list(source_node.labels)[0] if source_node.labels else None
                     )
                     target_label = (
                         list(connected_node.labels)[0]
                         if connected_node and connected_node.labels
                         else None
                     )

                     if source_label and target_label:
                         edges.append((source_label, target_label))
+            finally:
+                await results.consume()  # Ensure results are consumed even if processing fails

             return edges
@@ -427,30 +490,46 @@ class Neo4JStorage(BaseGraphStorage):
     ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.
+        Checks if both source and target nodes exist before creating the edge.

         Args:
             source_node_id (str): Label of the source node (used as identifier)
             target_node_id (str): Label of the target node (used as identifier)
             edge_data (dict): Dictionary of properties to set on the edge
+
+        Raises:
+            ValueError: If either source or target node does not exist
         """
         source_label = await self._ensure_label(source_node_id)
         target_label = await self._ensure_label(target_node_id)
         edge_properties = edge_data

+        # Check if both nodes exist
+        source_exists = await self.has_node(source_label)
+        target_exists = await self.has_node(target_label)
+
+        if not source_exists:
+            raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist")
+        if not target_exists:
+            raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist")
+
         async def _do_upsert_edge(tx: AsyncManagedTransaction):
             query = f"""
             MATCH (source:`{source_label}`)
             WITH source
             MATCH (target:`{target_label}`)
-            MERGE (source)-[r:DIRECTED]->(target)
+            MERGE (source)-[r:DIRECTED]-(target)
             SET r += $properties
             RETURN r
             """
             result = await tx.run(query, properties=edge_properties)
+            try:
                 record = await result.single()
                 logger.debug(
                     f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
                 )
+            finally:
+                await result.consume()  # Ensure result is consumed

         try:
             async with self._driver.session(database=self._DATABASE) as session:
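
Because upsert_edge now raises instead of silently matching nothing, callers are expected to create both endpoints first. A hedged sketch of calling code; it assumes the storage class also exposes upsert_node(node_id, node_data), which is not part of this diff, and the entity names are placeholders:

    async def link_entities(storage) -> None:
        await storage.upsert_node('"Alice"', {"entity_type": "person"})
        await storage.upsert_node('"Berlin"', {"entity_type": "location"})
        try:
            await storage.upsert_edge('"Alice"', '"Berlin"', {"relation": "lives_in"})
        except ValueError as exc:
            # Raised when either endpoint label does not exist yet.
            print(f"edge rejected: {exc}")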
@@ -463,145 +542,179 @@ class Neo4JStorage(BaseGraphStorage):
         print("Implemented but never called.")

     async def get_knowledge_graph(
-        self, node_label: str, max_depth: int = 5
+        self,
+        node_label: str,
+        max_depth: int = 3,
+        min_degree: int = 0,
+        inclusive: bool = False,
     ) -> KnowledgeGraph:
         """
         Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
         Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
         When reducing the number of nodes, the prioritization criteria are as follows:
-            1. Label matching nodes take precedence (nodes containing the specified label string)
-            2. Followed by nodes directly connected to the matching nodes
-            3. Finally, the degree of the nodes
+            1. min_degree does not affect nodes directly connected to the matching nodes
+            2. Label matching nodes take precedence
+            3. Followed by nodes directly connected to the matching nodes
+            4. Finally, the degree of the nodes

         Args:
-            node_label (str): String to match in node labels (will match any node containing this string in its label)
-            max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
+            node_label: Label of the starting node
+            max_depth: Maximum depth of the subgraph
+            min_degree: Minimum degree of nodes to include. Defaults to 0
+            inclusive: Do an inclusive search if true

         Returns:
             KnowledgeGraph: Complete connected subgraph for specified node
         """
         label = node_label.strip('"')
-        # Escape single quotes to prevent injection attacks
-        escaped_label = label.replace("'", "\\'")
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()

-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             try:
                 if label == "*":
                     main_query = """
                     MATCH (n)
                     OPTIONAL MATCH (n)-[r]-()
                     WITH n, count(r) AS degree
+                    WHERE degree >= $min_degree
                     ORDER BY degree DESC
                     LIMIT $max_nodes
-                    WITH collect(n) AS nodes
-                    MATCH (a)-[r]->(b)
-                    WHERE a IN nodes AND b IN nodes
-                    RETURN nodes, collect(DISTINCT r) AS relationships
+                    WITH collect({node: n}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                     )

                 else:
-                    validate_query = f"""
-                    MATCH (n)
-                    WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
-                    RETURN n LIMIT 1
-                    """
-                    validate_result = await session.run(validate_query)
-                    if not await validate_result.single():
-                        logger.warning(
-                            f"No nodes containing '{label}' in their labels found!"
-                        )
-                        return result
-
                     # Main query uses partial matching
-                    main_query = f"""
+                    main_query = """
                     MATCH (start)
-                    WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
+                    WHERE any(label IN labels(start) WHERE
+                          CASE
+                              WHEN $inclusive THEN label CONTAINS $label
+                              ELSE label = $label
+                          END
+                    )
                     WITH start
-                    CALL apoc.path.subgraphAll(start, {{
-                        relationshipFilter: '>',
+                    CALL apoc.path.subgraphAll(start, {
+                        relationshipFilter: '',
                         minLevel: 0,
-                        maxLevel: {max_depth},
+                        maxLevel: $max_depth,
                         bfs: true
-                    }})
+                    })
                     YIELD nodes, relationships
                     WITH start, nodes, relationships
                     UNWIND nodes AS node
                     OPTIONAL MATCH (node)-[r]-()
-                    WITH node, count(r) AS degree, start, nodes, relationships,
-                         CASE
-                             WHEN id(node) = id(start) THEN 2
-                             WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
-                             ELSE 0
-                         END AS priority
-                    ORDER BY priority DESC, degree DESC
+                    WITH node, count(r) AS degree, start, nodes, relationships
+                    WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
+                    ORDER BY
+                        CASE
+                            WHEN node = start THEN 3
+                            WHEN EXISTS((start)--(node)) THEN 2
+                            ELSE 1
+                        END DESC,
+                        degree DESC
                     LIMIT $max_nodes
-                    WITH collect(node) AS filtered_nodes, nodes, relationships
-                    RETURN filtered_nodes AS nodes,
-                           [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
+                    WITH collect({node: node}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {
+                            "max_nodes": MAX_GRAPH_NODES,
+                            "label": label,
+                            "inclusive": inclusive,
+                            "max_depth": max_depth,
+                            "min_degree": min_degree,
+                        },
                     )
+                try:
                     record = await result_set.single()

                     if record:
                         # Handle nodes (compatible with multi-label cases)
-                        for node in record["nodes"]:
-                            # Use node ID + label combination as unique identifier
+                        for node_info in record["node_info"]:
+                            node = node_info["node"]
                             node_id = node.id
                             if node_id not in seen_nodes:
                                 result.nodes.append(
                                     KnowledgeGraphNode(
                                         id=f"{node_id}",
                                         labels=list(node.labels),
                                         properties=dict(node),
                                     )
                                 )
                                 seen_nodes.add(node_id)

                         # Handle relationships (including direction information)
                         for rel in record["relationships"]:
                             edge_id = rel.id
                             if edge_id not in seen_edges:
                                 start = rel.start_node
                                 end = rel.end_node
                                 result.edges.append(
                                     KnowledgeGraphEdge(
                                         id=f"{edge_id}",
                                         type=rel.type,
                                         source=f"{start.id}",
                                         target=f"{end.id}",
                                         properties=dict(rel),
                                     )
                                 )
                                 seen_edges.add(edge_id)

                         logger.info(
                             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                         )
+                finally:
+                    await result_set.consume()  # Ensure result set is consumed

             except neo4jExceptions.ClientError as e:
-                logger.error(f"APOC query failed: {str(e)}")
-                return await self._robust_fallback(label, max_depth)
+                logger.warning(
+                    f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
+                )
+                if inclusive:
+                    logger.warning(
+                        "Inclusive search mode is not supported in recursive query, using exact matching"
+                    )
+                return await self._robust_fallback(label, max_depth, min_degree)

         return result
     async def _robust_fallback(
-        self, label: str, max_depth: int
+        self, label: str, max_depth: int, min_degree: int = 0
     ) -> Dict[str, List[Dict]]:
-        """Enhanced fallback query solution"""
+        """
+        Fallback implementation when APOC plugin is not available or incompatible.
+
+        This method implements the same functionality as get_knowledge_graph but uses
+        only basic Cypher queries and recursive traversal instead of APOC procedures.
+        """
         result = {"nodes": [], "edges": []}
         visited_nodes = set()
         visited_edges = set()

         async def traverse(current_label: str, current_depth: int):
+            # Check traversal limits
             if current_depth > max_depth:
+                logger.debug(f"Reached max depth: {max_depth}")
+                return
+            if len(visited_nodes) >= MAX_GRAPH_NODES:
+                logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
                 return

             # Get current node details
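
A hedged usage sketch of the extended get_knowledge_graph signature; `storage` stands for an initialized Neo4JStorage instance and the label value is a placeholder:

    async def preview_subgraph(storage) -> None:
        # inclusive=True matches any node whose label contains "Alice";
        # min_degree=2 drops weakly connected nodes unless they are direct
        # neighbours of the matched start node (see the query above).
        kg = await storage.get_knowledge_graph(
            node_label="Alice",
            max_depth=3,
            min_degree=2,
            inclusive=True,
        )
        print(f"{len(kg.nodes)} nodes / {len(kg.edges)} edges")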
@@ -614,46 +727,46 @@ class Neo4JStorage(BaseGraphStorage):
                 return

             visited_nodes.add(node_id)

-            # Add node data (with complete labels)
-            node_data = {k: v for k, v in node.items()}
-            node_data["labels"] = [
-                current_label
-            ]  # Assume get_node method returns label information
-            result["nodes"].append(node_data)
+            # Add node data with label as ID
+            result["nodes"].append({
+                "id": current_label,
+                "labels": current_label,
+                "properties": node
+            })

-            # Get all outgoing and incoming edges
+            # Get connected nodes that meet the degree requirement
+            # Note: We don't need to check a's degree since it's the current node
+            # and was already validated in the previous iteration
             query = f"""
-            MATCH (a)-[r]-(b)
-            WHERE a:`{current_label}` OR b:`{current_label}`
-            RETURN a, r, b,
-                   CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction
+            MATCH (a:`{current_label}`)-[r]-(b)
+            WITH r, b,
+                 COUNT((b)--()) AS b_degree
+            WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
+            RETURN r, b
             """

-            async with self._driver.session(database=self._DATABASE) as session:
-                results = await session.run(query)
+            async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+                results = await session.run(query, {"min_degree": min_degree})
                 async for record in results:
                     # Handle edges
                     rel = record["r"]
                     edge_id = f"{rel.id}_{rel.type}"
                     if edge_id not in visited_edges:
-                        edge_data = dict(rel)
-                        edge_data.update(
-                            {
-                                "source": list(record["a"].labels)[0],
-                                "target": list(record["b"].labels)[0],
-                                "type": rel.type,
-                                "direction": record["direction"],
-                            }
-                        )
-                        result["edges"].append(edge_data)
-                        visited_edges.add(edge_id)
-
-                        # Recursively traverse adjacent nodes
-                        next_label = (
-                            list(record["b"].labels)[0]
-                            if record["direction"] == "OUTGOING"
-                            else list(record["a"].labels)[0]
-                        )
-                        await traverse(next_label, current_depth + 1)
+                        b_node = record["b"]
+                        if b_node.labels:  # Only process if target node has labels
+                            target_label = list(b_node.labels)[0]
+                            result["edges"].append({
+                                "id": f"{current_label}_{target_label}",
+                                "type": rel.type,
+                                "source": current_label,
+                                "target": target_label,
+                                "properties": dict(rel)
+                            })
+                            visited_edges.add(edge_id)
+
+                            # Continue traversal
+                            await traverse(target_label, current_depth + 1)
+                        else:
+                            logger.warning(f"Skipping edge {edge_id} due to missing labels on target node")

         await traverse(label, 0)
         return result
@@ -664,7 +777,7 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             ["Person", "Company", ...]  # Alphabetically sorted label list
         """
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             # Method 1: Direct metadata query (Available for Neo4j 4.3+)
             # query = "CALL db.labels() YIELD label RETURN label"
@@ -679,8 +792,11 @@ class Neo4JStorage(BaseGraphStorage):
             result = await session.run(query)
             labels = []
+            try:
                 async for record in result:
                     labels.append(record["label"])
+            finally:
+                await result.consume()  # Ensure results are consumed even if processing fails
             return labels

     @retry(
@@ -763,7 +879,7 @@ class Neo4JStorage(BaseGraphStorage):
         async def _do_delete_edge(tx: AsyncManagedTransaction):
             query = f"""
-            MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
+            MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
             DELETE r
             """
             await tx.run(query)