diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f5c2237a..2fb2c494 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -23,7 +23,7 @@ import pipmaster as pm if not pm.is_installed("neo4j"): pm.install("neo4j") -from neo4j import ( +from neo4j import ( # type: ignore AsyncGraphDatabase, exceptions as neo4jExceptions, AsyncDriver, @@ -34,6 +34,9 @@ from neo4j import ( config = configparser.ConfigParser() config.read("config.ini", "utf-8") +# 从环境变量获取最大图节点数,默认为1000 +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + @final @dataclass @@ -471,12 +474,17 @@ class Neo4JStorage(BaseGraphStorage): ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence + 2. Followed by nodes directly connected to the matching nodes + 3. Finally, the degree of the nodes - Key fixes: - 1. Include the starting node itself - 2. Handle multi-label nodes - 3. Clarify relationship directions - 4. Add depth control + Args: + node_label (str): Label of the starting node + max_depth (int, optional): Maximum depth of the graph. Defaults to 5. 
+ Returns: + KnowledgeGraph: Complete connected subgraph for specified node """ label = node_label.strip('"') result = KnowledgeGraph() @@ -485,14 +493,22 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session(database=self._DATABASE) as session: try: - main_query = "" if label == "*": main_query = """ MATCH (n) - WITH collect(DISTINCT n) AS nodes - MATCH ()-[r]-() - RETURN nodes, collect(DISTINCT r) AS relationships; + OPTIONAL MATCH (n)-[r]-() + WITH n, count(r) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect(n) AS nodes + MATCH (a)-[r]->(b) + WHERE a IN nodes AND b IN nodes + RETURN nodes, collect(DISTINCT r) AS relationships """ + result_set = await session.run( + main_query, {"max_nodes": MAX_GRAPH_NODES} + ) + else: # Critical debug step: first verify if starting node exists validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" @@ -512,9 +528,25 @@ class Neo4JStorage(BaseGraphStorage): bfs: true }}) YIELD nodes, relationships - RETURN nodes, relationships + WITH start, nodes, relationships + UNWIND nodes AS node + OPTIONAL MATCH (node)-[r]-() + WITH node, count(r) AS degree, start, nodes, relationships, + CASE + WHEN id(node) = id(start) THEN 2 + WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1 + ELSE 0 + END AS priority + ORDER BY priority DESC, degree DESC + LIMIT $max_nodes + WITH collect(node) AS filtered_nodes, nodes, relationships + RETURN filtered_nodes AS nodes, + [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships """ - result_set = await session.run(main_query) + result_set = await session.run( + main_query, {"max_nodes": MAX_GRAPH_NODES} + ) + record = await result_set.single() if record: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index b1cc45fe..462fb832 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -236,7 +236,11 @@ class NetworkXStorage(BaseGraphStorage): ) -> 
KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) - Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence + 2. Followed by nodes directly connected to the matching nodes + 3. Finally, the degree of the nodes Args: node_label: Label of the starting node @@ -268,14 +272,49 @@ class NetworkXStorage(BaseGraphStorage): logger.warning(f"No nodes found with label {node_label}") return result - # Get subgraph using ego_graph - subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) + # Get subgraph using ego_graph from all matching nodes + combined_subgraph = nx.Graph() + for start_node in nodes_to_explore: + node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) + combined_subgraph = nx.compose(combined_subgraph, node_subgraph) + subgraph = combined_subgraph # Check if number of nodes exceeds max_graph_nodes if len(subgraph.nodes()) > MAX_GRAPH_NODES: origin_nodes = len(subgraph.nodes()) + + # 获取节点度数 node_degrees = dict(subgraph.degree()) - top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ + + # 标记起点节点和直接连接的节点 + start_nodes = set() + direct_connected_nodes = set() + + if node_label != "*" and nodes_to_explore: + # 所有在 nodes_to_explore 中的节点都是起点节点 + start_nodes = set(nodes_to_explore) + + # 获取与所有起点直接连接的节点 + for start_node in start_nodes: + direct_connected_nodes.update(subgraph.neighbors(start_node)) + + # 从直接连接节点中移除起点节点(避免重复) + direct_connected_nodes -= start_nodes + + # 按优先级和度数排序 + def priority_key(node_item): + node, degree = node_item + # 优先级排序:起点(2) > 直接连接(1) > 其他节点(0) + if node in start_nodes: + priority = 2 + elif node in direct_connected_nodes: + priority = 1 + else: + priority = 0 + return (priority, degree) # 
sort by priority first, then by degree + + # Sort and keep the top MAX_GRAPH_NODES nodes + top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[ :MAX_GRAPH_NODES ] top_node_ids = [node[0] for node in top_nodes]