From af803f4e7ad3267fcd184fd6c3914b4c6b2c6bef Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 01:20:36 +0800
Subject: [PATCH] Refactor Neo4J graph query with min_degree and inclusive
 match support

---
 lightrag/kg/neo4j_impl.py | 434 ++++++++++++++++++++++++--------------
 1 file changed, 275 insertions(+), 159 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 265c0347..f6567249 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -41,6 +41,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
 # Set neo4j logger level to ERROR to suppress warning logs
 logging.getLogger("neo4j").setLevel(logging.ERROR)
 
+
 @final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
@@ -63,19 +64,25 @@ class Neo4JStorage(BaseGraphStorage):
         MAX_CONNECTION_POOL_SIZE = int(
             os.environ.get(
                 "NEO4J_MAX_CONNECTION_POOL_SIZE",
-                config.get("neo4j", "connection_pool_size", fallback=800),
+                config.get("neo4j", "connection_pool_size", fallback=50),  # Reduced from 800
             )
         )
         CONNECTION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_TIMEOUT",
-                config.get("neo4j", "connection_timeout", fallback=60.0),
+                config.get("neo4j", "connection_timeout", fallback=30.0),  # Reduced from 60.0
             ),
         )
         CONNECTION_ACQUISITION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
-                config.get("neo4j", "connection_acquisition_timeout", fallback=60.0),
+                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),  # Reduced from 60.0
+            ),
+        )
+        MAX_TRANSACTION_RETRY_TIME = float(
+            os.environ.get(
+                "NEO4J_MAX_TRANSACTION_RETRY_TIME",
+                config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
             ),
         )
         DATABASE = os.environ.get(
@@ -88,6 +95,7 @@ class Neo4JStorage(BaseGraphStorage):
             max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
             connection_timeout=CONNECTION_TIMEOUT,
             connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
+            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
         )
 
         # Try to connect to the database
@@ -169,21 +177,24 @@ class Neo4JStorage(BaseGraphStorage):
     async def _ensure_label(self, label: str) -> str:
         """Ensure a label is valid
-
+
         Args:
             label: The label to validate
         """
         clean_label = label.strip('"')
+        if not clean_label:
+            raise ValueError("Neo4j: Label cannot be empty")
         return clean_label
 
     async def has_node(self, node_id: str) -> bool:
         entity_name_label = await self._ensure_label(node_id)
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = (
                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
             )
             result = await session.run(query)
             single_result = await result.single()
+            await result.consume()  # Ensure result is fully consumed
             logger.debug(
                 f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
             )
@@ -193,13 +204,14 @@ class Neo4JStorage(BaseGraphStorage):
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')
 
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = (
                 f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
                 "RETURN COUNT(r) > 0 AS edgeExists"
             )
             result = await session.run(query)
             single_result = await result.single()
+            await result.consume()  # Ensure result is fully consumed
             logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}" ) @@ -215,13 +227,16 @@ class Neo4JStorage(BaseGraphStorage): dict: Node properties if found None: If node not found """ - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: entity_name_label = await self._ensure_label(node_id) query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) - record = await result.single() - if record: - node = record["n"] + records = await result.fetch(2) # Get up to 2 records to check for duplicates + await result.consume() # Ensure result is fully consumed + if len(records) > 1: + logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.") + if records: + node = records[0]["n"] node_dict = dict(node) logger.debug( f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" @@ -230,23 +245,40 @@ class Neo4JStorage(BaseGraphStorage): return None async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + """ entity_name_label = node_id.strip('"') - async with self._driver.session(database=self._DATABASE) as session: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: query = f""" MATCH (n:`{entity_name_label}`) - RETURN COUNT{{ (n)--() }} AS totalEdgeCount + OPTIONAL MATCH (n)-[r]-() + RETURN n, COUNT(r) AS degree """ result = await session.run(query) - record = await result.single() - if record: - edge_count = record["totalEdgeCount"] - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}" - ) - return edge_count - else: - return None + records = await result.fetch(100) + await result.consume() # Ensure result is fully consumed + + if not records: + logger.warning(f"No node found with label '{entity_name_label}'") + return 0 + + if len(records) > 1: + logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree") + + degree = records[0]["degree"] + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" + ) + return degree async def edge_degree(self, src_id: str, tgt_id: str) -> int: entity_name_label_source = src_id.strip('"') @@ -264,6 +296,31 @@ class Neo4JStorage(BaseGraphStorage): ) return degrees + async def check_duplicate_nodes(self) -> list[tuple[str, int]]: + """Find all labels that have multiple nodes + + Returns: + list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes + """ + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n) + WITH labels(n) as nodeLabels + UNWIND nodeLabels as label + WITH label, count(*) as node_count + WHERE node_count > 1 + RETURN label, node_count + ORDER BY node_count DESC + """ + result = await session.run(query) + duplicates = [] + async for record in result: + label = record["label"] + count = record["node_count"] + logger.info(f"Found {count} nodes with label: {label}") + duplicates.append((label, count)) + return duplicates + async def get_edge( 
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
@@ -271,18 +328,21 @@ class Neo4JStorage(BaseGraphStorage):
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')
 
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = f"""
-            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
+            MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
             RETURN properties(r) as edge_properties
-            LIMIT 1
             """
             result = await session.run(query)
-            record = await result.single()
-            if record:
+            records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+            if len(records) > 1:
+                logger.warning(
+                    f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
+                )
+            if records:
                 try:
-                    result = dict(record["edge_properties"])
+                    result = dict(records[0]["edge_properties"])
                     logger.debug(f"Result: {result}")
                     # Ensure required keys exist with defaults
                     required_keys = {
@@ -349,24 +409,27 @@ class Neo4JStorage(BaseGraphStorage):
         query = f"""MATCH (n:`{node_label}`)
                 OPTIONAL MATCH (n)-[r]-(connected)
                 RETURN n, r, connected"""
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             results = await session.run(query)
             edges = []
-            async for record in results:
-                source_node = record["n"]
-                connected_node = record["connected"]
+            try:
+                async for record in results:
+                    source_node = record["n"]
+                    connected_node = record["connected"]
 
-                source_label = (
-                    list(source_node.labels)[0] if source_node.labels else None
-                )
-                target_label = (
-                    list(connected_node.labels)[0]
-                    if connected_node and connected_node.labels
-                    else None
-                )
+                    source_label = (
+                        list(source_node.labels)[0] if source_node.labels else None
+                    )
+                    target_label = (
+                        list(connected_node.labels)[0]
+                        if connected_node and connected_node.labels
+                        else None
+                    )
 
-                if source_label and target_label:
-                    edges.append((source_label, target_label))
+                    if source_label and target_label:
+                        edges.append((source_label, target_label))
+            finally:
+                await results.consume()  # Ensure results are consumed even if processing fails
 
             return edges
 
@@ -427,30 +490,46 @@ class Neo4JStorage(BaseGraphStorage):
     ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.
+        Checks if both source and target nodes exist before creating the edge.
 
         Args:
             source_node_id (str): Label of the source node (used as identifier)
             target_node_id (str): Label of the target node (used as identifier)
             edge_data (dict): Dictionary of properties to set on the edge
+
+        Raises:
+            ValueError: If either source or target node does not exist
         """
         source_label = await self._ensure_label(source_node_id)
         target_label = await self._ensure_label(target_node_id)
         edge_properties = edge_data
 
+        # Check if both nodes exist
+        source_exists = await self.has_node(source_label)
+        target_exists = await self.has_node(target_label)
+
+        if not source_exists:
+            raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist")
+        if not target_exists:
+            raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist")
+
         async def _do_upsert_edge(tx: AsyncManagedTransaction):
             query = f"""
             MATCH (source:`{source_label}`)
             WITH source
             MATCH (target:`{target_label}`)
-            MERGE (source)-[r:DIRECTED]->(target)
+            MERGE (source)-[r:DIRECTED]-(target)
             SET r += $properties
             RETURN r
             """
             result = await tx.run(query, properties=edge_properties)
-            record = await result.single()
-            logger.debug(
-                f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
-            )
+            try:
+                record = await result.single()
+                logger.debug(
+                    f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
+                )
+            finally:
+                await result.consume()  # Ensure result is consumed
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
@@ -463,145 +542,179 @@ class Neo4JStorage(BaseGraphStorage):
         print("Implemented but never called.")
 
     async def get_knowledge_graph(
-        self, node_label: str, max_depth: int = 5
+        self,
+        node_label: str,
+        max_depth: int = 3,
+        min_degree: int = 0,
+        inclusive: bool = False,
     ) -> KnowledgeGraph:
         """
         Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
         Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
         When reducing the number of nodes, the prioritization criteria are as follows:
-            1. Label matching nodes take precedence (nodes containing the specified label string)
-            2. Followed by nodes directly connected to the matching nodes
-            3. Finally, the degree of the nodes
+            1. min_degree does not affect nodes directly connected to the matching nodes
+            2. Label matching nodes take precedence
+            3. Followed by nodes directly connected to the matching nodes
+            4. Finally, the degree of the nodes
 
         Args:
-            node_label (str): String to match in node labels (will match any node containing this string in its label)
-            max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
+            node_label: Label of the starting node
+            max_depth: Maximum depth of the subgraph
+            min_degree: Minimum degree of nodes to include. Defaults to 0
+            inclusive: If True, do an inclusive (substring) label match; otherwise match the label exactly
 
         Returns:
             KnowledgeGraph: Complete connected subgraph for specified node
         """
         label = node_label.strip('"')
-        # Escape single quotes to prevent injection attacks
-        escaped_label = label.replace("'", "\\'")
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()
 
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             try:
                 if label == "*":
                     main_query = """
                     MATCH (n)
                     OPTIONAL MATCH (n)-[r]-()
                     WITH n, count(r) AS degree
+                    WHERE degree >= $min_degree
                     ORDER BY degree DESC
                     LIMIT $max_nodes
-                    WITH collect(n) AS nodes
-                    MATCH (a)-[r]->(b)
-                    WHERE a IN nodes AND b IN nodes
-                    RETURN nodes, collect(DISTINCT r) AS relationships
+                    WITH collect({node: n}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                     )
 
                 else:
-                    validate_query = f"""
-                    MATCH (n)
-                    WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
-                    RETURN n LIMIT 1
-                    """
-                    validate_result = await session.run(validate_query)
-                    if not await validate_result.single():
-                        logger.warning(
-                            f"No nodes containing '{label}' in their labels found!"
-                        )
-                        return result
-
-                    # Main query uses partial matching
-                    main_query = f"""
+                    main_query = """
                     MATCH (start)
-                    WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
+                    WHERE any(label IN labels(start) WHERE
+                        CASE
+                            WHEN $inclusive THEN label CONTAINS $label
+                            ELSE label = $label
+                        END
+                    )
                     WITH start
-                    CALL apoc.path.subgraphAll(start, {{
-                        relationshipFilter: '>',
+                    CALL apoc.path.subgraphAll(start, {
+                        relationshipFilter: '',
                         minLevel: 0,
-                        maxLevel: {max_depth},
+                        maxLevel: $max_depth,
                         bfs: true
-                    }})
+                    })
                     YIELD nodes, relationships
                     WITH start, nodes, relationships
                     UNWIND nodes AS node
                     OPTIONAL MATCH (node)-[r]-()
-                    WITH node, count(r) AS degree, start, nodes, relationships,
-                         CASE
-                             WHEN id(node) = id(start) THEN 2
-                             WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
-                             ELSE 0
-                         END AS priority
-                    ORDER BY priority DESC, degree DESC
+                    WITH node, count(r) AS degree, start, nodes, relationships
+                    WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
+                    ORDER BY
+                        CASE
+                            WHEN node = start THEN 3
+                            WHEN EXISTS((start)--(node)) THEN 2
+                            ELSE 1
+                        END DESC,
+                        degree DESC
                     LIMIT $max_nodes
-                    WITH collect(node) AS filtered_nodes, nodes, relationships
-                    RETURN filtered_nodes AS nodes,
-                           [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
+                    WITH collect({node: node}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {
+                            "max_nodes": MAX_GRAPH_NODES,
+                            "label": label,
+                            "inclusive": inclusive,
+                            "max_depth": max_depth,
+                            "min_degree": min_degree,
+                        },
                     )
 
-                record = await result_set.single()
+                try:
+                    record = await result_set.single()
 
-                if record:
-                    # Handle nodes (compatible with multi-label cases)
-                    for node in record["nodes"]:
-                        # Use node ID + label combination as unique identifier
-                        node_id = node.id
-                        if node_id not in seen_nodes:
-                            result.nodes.append(
-                                KnowledgeGraphNode(
-                                    id=f"{node_id}",
-                                    labels=list(node.labels),
-                                    properties=dict(node),
+                    if record:
+                        # Handle nodes (compatible with multi-label cases)
+                        for node_info in record["node_info"]:
+                            node = node_info["node"]
+                            node_id = node.id
+                            if node_id not in seen_nodes:
+                                result.nodes.append(
+                                    KnowledgeGraphNode(
+                                        id=f"{node_id}",
+                                        labels=list(node.labels),
+                                        properties=dict(node),
+                                    )
                                 )
-                            )
-                            seen_nodes.add(node_id)
+                                seen_nodes.add(node_id)
 
-                    # Handle relationships (including direction information)
-                    for rel in record["relationships"]:
-                        edge_id = rel.id
-                        if edge_id not in seen_edges:
-                            start = rel.start_node
-                            end = rel.end_node
-                            result.edges.append(
-                                KnowledgeGraphEdge(
-                                    id=f"{edge_id}",
-                                    type=rel.type,
-                                    source=f"{start.id}",
-                                    target=f"{end.id}",
-                                    properties=dict(rel),
+                        # Handle relationships (including direction information)
+                        for rel in record["relationships"]:
+                            edge_id = rel.id
+                            if edge_id not in seen_edges:
+                                start = rel.start_node
+                                end = rel.end_node
+                                result.edges.append(
+                                    KnowledgeGraphEdge(
+                                        id=f"{edge_id}",
+                                        type=rel.type,
+                                        source=f"{start.id}",
+                                        target=f"{end.id}",
+                                        properties=dict(rel),
+                                    )
                                 )
-                            )
-                            seen_edges.add(edge_id)
+                                seen_edges.add(edge_id)
 
-                logger.info(
-                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
-                )
+                        logger.info(
+                            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+                        )
+                finally:
+                    await result_set.consume()  # Ensure result set is consumed
 
             except neo4jExceptions.ClientError as e:
-                logger.error(f"APOC query failed: {str(e)}")
-                return await self._robust_fallback(label, max_depth)
+                logger.warning(
+                    f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
+                )
+                if inclusive:
+                    logger.warning(
+                        "Inclusive search mode is not supported in recursive query, using exact matching"
+                    )
+                return await self._robust_fallback(label, max_depth, min_degree)
 
         return result
 
     async def _robust_fallback(
-        self, label: str, max_depth: int
+        self, label: str, max_depth: int, min_degree: int = 0
     ) -> Dict[str, List[Dict]]:
-        """Enhanced fallback query solution"""
+        """
+        Fallback implementation when APOC plugin is not available or incompatible.
+        This method implements the same functionality as get_knowledge_graph but uses
+        only basic Cypher queries and recursive traversal instead of APOC procedures.
+ """ result = {"nodes": [], "edges": []} visited_nodes = set() visited_edges = set() async def traverse(current_label: str, current_depth: int): + # Check traversal limits if current_depth > max_depth: + logger.debug(f"Reached max depth: {max_depth}") + return + if len(visited_nodes) >= MAX_GRAPH_NODES: + logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") return # Get current node details @@ -614,46 +727,46 @@ class Neo4JStorage(BaseGraphStorage): return visited_nodes.add(node_id) - # Add node data (with complete labels) - node_data = {k: v for k, v in node.items()} - node_data["labels"] = [ - current_label - ] # Assume get_node method returns label information - result["nodes"].append(node_data) + # Add node data with label as ID + result["nodes"].append({ + "id": current_label, + "labels": current_label, + "properties": node + }) - # Get all outgoing and incoming edges + # Get connected nodes that meet the degree requirement + # Note: We don't need to check a's degree since it's the current node + # and was already validated in the previous iteration query = f""" - MATCH (a)-[r]-(b) - WHERE a:`{current_label}` OR b:`{current_label}` - RETURN a, r, b, - CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction + MATCH (a:`{current_label}`)-[r]-(b) + WITH r, b, + COUNT((b)--()) AS b_degree + WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) + RETURN r, b """ - async with self._driver.session(database=self._DATABASE) as session: - results = await session.run(query) + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + results = await session.run(query, {"min_degree": min_degree}) async for record in results: # Handle edges rel = record["r"] edge_id = f"{rel.id}_{rel.type}" if edge_id not in visited_edges: - edge_data = dict(rel) - edge_data.update( - { - "source": list(record["a"].labels)[0], - "target": list(record["b"].labels)[0], + b_node = record["b"] + if b_node.labels: # Only process if target node has labels + target_label = list(b_node.labels)[0] + result["edges"].append({ + "id": f"{current_label}_{target_label}", "type": rel.type, - "direction": record["direction"], - } - ) - result["edges"].append(edge_data) - visited_edges.add(edge_id) + "source": current_label, + "target": target_label, + "properties": dict(rel) + }) + visited_edges.add(edge_id) - # Recursively traverse adjacent nodes - next_label = ( - list(record["b"].labels)[0] - if record["direction"] == "OUTGOING" - else list(record["a"].labels)[0] - ) - await traverse(next_label, current_depth + 1) + # Continue traversal + await traverse(target_label, current_depth + 1) + else: + logger.warning(f"Skipping edge {edge_id} due to missing labels on target node") await traverse(label, 0) return result @@ -664,7 +777,7 @@ class Neo4JStorage(BaseGraphStorage): Returns: ["Person", "Company", ...] 
         """
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             # Method 1: Direct metadata query (Available for Neo4j 4.3+)
             # query = "CALL db.labels() YIELD label RETURN label"
@@ -679,8 +792,11 @@ class Neo4JStorage(BaseGraphStorage):
             result = await session.run(query)
             labels = []
-            async for record in result:
-                labels.append(record["label"])
+            try:
+                async for record in result:
+                    labels.append(record["label"])
+            finally:
+                await result.consume()  # Ensure results are consumed even if processing fails
             return labels
 
     @retry(
@@ -763,7 +879,7 @@ class Neo4JStorage(BaseGraphStorage):
 
         async def _do_delete_edge(tx: AsyncManagedTransaction):
             query = f"""
-            MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
+            MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
             DELETE r
             """
             await tx.run(query)
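
Usage sketch (not part of the patch): a minimal illustration of the new get_knowledge_graph keyword arguments and the check_duplicate_nodes helper introduced above. It assumes `storage` is an already-initialized Neo4JStorage instance, and the "Person" label is only a placeholder.

import asyncio

from lightrag.kg.neo4j_impl import Neo4JStorage


async def demo(storage: Neo4JStorage) -> None:
    # Substring label match (inclusive=True), depth-limited to 3; nodes with
    # fewer than 2 relationships are kept only if directly connected to a
    # matching node, per the prioritization rules in the docstring.
    kg = await storage.get_knowledge_graph(
        "Person", max_depth=3, min_degree=2, inclusive=True
    )
    print(f"nodes={len(kg.nodes)} edges={len(kg.edges)}")

    # List labels shared by more than one node (possible duplicate entities).
    for label, count in await storage.check_duplicate_nodes():
        print(f"{label}: {count} nodes")


# asyncio.run(demo(storage))  # constructing `storage` is environment-specific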