diff --git a/env.example b/env.example index 112676c6..99909ac6 100644 --- a/env.example +++ b/env.example @@ -5,6 +5,7 @@ # PORT=9621 # WORKERS=1 # NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances +# MAX_GRAPH_NODES=1000 # Max nodes returned from graph retrieval # CORS_ORIGINS=http://localhost:3000,http://localhost:8080 ### Optional SSL Configuration diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index 95a72758..e6f894a2 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -16,12 +16,32 @@ def create_graph_routes(rag, api_key: Optional[str] = None): @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)]) async def get_graph_labels(): - """Get all graph labels""" + """ + Get all graph labels + + Returns: + List[str]: List of graph labels + """ return await rag.get_graph_labels() @router.get("/graphs", dependencies=[Depends(optional_api_key)]) async def get_knowledge_graph(label: str, max_depth: int = 3): - """Get knowledge graph for a specific label""" + """ + Retrieve a connected subgraph of nodes where the label includes the specified label. + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence + 2. Followed by nodes directly connected to the matching nodes + 3. Finally, the degree of the nodes + Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) + + Args: + label (str): Label to get knowledge graph for + max_depth (int, optional): Maximum depth of graph. Defaults to 3. 
+ + Returns: + Dict[str, List[str]]: Knowledge graph for label + """ return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth) return router diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f5c2237a..dccee330 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -23,7 +23,7 @@ import pipmaster as pm if not pm.is_installed("neo4j"): pm.install("neo4j") -from neo4j import ( +from neo4j import ( # type: ignore AsyncGraphDatabase, exceptions as neo4jExceptions, AsyncDriver, @@ -34,6 +34,9 @@ from neo4j import ( config = configparser.ConfigParser() config.read("config.ini", "utf-8") +# Get maximum number of graph nodes from environment variable, default is 1000 +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + @final @dataclass @@ -470,40 +473,61 @@ class Neo4JStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence (nodes containing the specified label string) + 2. Followed by nodes directly connected to the matching nodes + 3. Finally, the degree of the nodes - Key fixes: - 1. Include the starting node itself - 2. Handle multi-label nodes - 3. Clarify relationship directions - 4. Add depth control + Args: + node_label (str): String to match in node labels (will match any node containing this string in its label) + max_depth (int, optional): Maximum depth of the graph. Defaults to 5. 
+ Returns: + KnowledgeGraph: Complete connected subgraph for specified node """ label = node_label.strip('"') + # Escape single quotes to prevent injection attacks + escaped_label = label.replace("'", "\\'") result = KnowledgeGraph() seen_nodes = set() seen_edges = set() async with self._driver.session(database=self._DATABASE) as session: try: - main_query = "" if label == "*": main_query = """ MATCH (n) - WITH collect(DISTINCT n) AS nodes - MATCH ()-[r]-() - RETURN nodes, collect(DISTINCT r) AS relationships; + OPTIONAL MATCH (n)-[r]-() + WITH n, count(r) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect(n) AS nodes + MATCH (a)-[r]->(b) + WHERE a IN nodes AND b IN nodes + RETURN nodes, collect(DISTINCT r) AS relationships """ + result_set = await session.run( + main_query, {"max_nodes": MAX_GRAPH_NODES} + ) + else: - # Critical debug step: first verify if starting node exists - validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" + validate_query = f""" + MATCH (n) + WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}') + RETURN n LIMIT 1 + """ validate_result = await session.run(validate_query) if not await validate_result.single(): - logger.warning(f"Starting node {label} does not exist!") + logger.warning( + f"No nodes containing '{label}' in their labels found!" 
+ ) return result - # Optimized query (including direction handling and self-loops) + # Main query uses partial matching main_query = f""" - MATCH (start:`{label}`) + MATCH (start) + WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}') WITH start CALL apoc.path.subgraphAll(start, {{ relationshipFilter: '>', @@ -512,9 +536,25 @@ class Neo4JStorage(BaseGraphStorage): bfs: true }}) YIELD nodes, relationships - RETURN nodes, relationships + WITH start, nodes, relationships + UNWIND nodes AS node + OPTIONAL MATCH (node)-[r]-() + WITH node, count(r) AS degree, start, nodes, relationships, + CASE + WHEN id(node) = id(start) THEN 2 + WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1 + ELSE 0 + END AS priority + ORDER BY priority DESC, degree DESC + LIMIT $max_nodes + WITH collect(node) AS filtered_nodes, nodes, relationships + RETURN filtered_nodes AS nodes, + [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships """ - result_set = await session.run(main_query) + result_set = await session.run( + main_query, {"max_nodes": MAX_GRAPH_NODES} + ) + record = await result_set.single() if record: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index f11e9c0e..563fc554 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -24,6 +24,8 @@ from .shared_storage import ( is_multiprocess, ) +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + @final @dataclass @@ -233,7 +235,12 @@ class NetworkXStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). 
+ When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence + 2. Followed by nodes directly connected to the matching nodes + 3. Finally, the degree of the nodes Args: node_label: Label of the starting node @@ -265,22 +272,51 @@ class NetworkXStorage(BaseGraphStorage): logger.warning(f"No nodes found with label {node_label}") return result - # Get subgraph using ego_graph - subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) + # Get subgraph using ego_graph from all matching nodes + combined_subgraph = nx.Graph() + for start_node in nodes_to_explore: + node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) + combined_subgraph = nx.compose(combined_subgraph, node_subgraph) + subgraph = combined_subgraph # Check if number of nodes exceeds max_graph_nodes - max_graph_nodes = 500 - if len(subgraph.nodes()) > max_graph_nodes: + if len(subgraph.nodes()) > MAX_GRAPH_NODES: origin_nodes = len(subgraph.nodes()) + node_degrees = dict(subgraph.degree()) - top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ - :max_graph_nodes + + start_nodes = set() + direct_connected_nodes = set() + + if node_label != "*" and nodes_to_explore: + start_nodes = set(nodes_to_explore) + # Get nodes directly connected to all start nodes + for start_node in start_nodes: + direct_connected_nodes.update(subgraph.neighbors(start_node)) + + # Remove start nodes from directly connected nodes (avoid duplicates) + direct_connected_nodes -= start_nodes + + def priority_key(node_item): + node, degree = node_item + # Priority order: start(2) > directly connected(1) > other nodes(0) + if node in start_nodes: + priority = 2 + elif node in direct_connected_nodes: + priority = 1 + else: + priority = 0 + return (priority, degree) + + # Sort by priority and degree and select top MAX_GRAPH_NODES nodes + top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[ + :MAX_GRAPH_NODES ] 
+ top_node_ids = [node[0] for node in top_nodes] - # Create new subgraph with only top nodes + # Create new subgraph keeping only the top nodes by priority, then degree subgraph = subgraph.subgraph(top_node_ids) logger.info( - f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" + f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})" ) # Add nodes to result @@ -320,7 +356,7 @@ class NetworkXStorage(BaseGraphStorage): result.edges.append( KnowledgeGraphEdge( id=edge_id, - type="DIRECTED", + type="RELATED", source=str(source), target=str(target), properties=edge_data, diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a2d806b6..a5d3c94b 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1173,7 +1173,7 @@ class LightRAG: """ if param.mode in ["local", "global", "hybrid"]: response = await kg_query( - query, + query.strip(), self.chunk_entity_relation_graph, self.entities_vdb, self.relationships_vdb, @@ -1194,7 +1194,7 @@ class LightRAG: ) elif param.mode == "naive": response = await naive_query( - query, + query.strip(), self.chunks_vdb, self.text_chunks, param, @@ -1213,7 +1213,7 @@ class LightRAG: ) elif param.mode == "mix": response = await mix_kg_vector_query( - query, + query.strip(), self.chunk_entity_relation_graph, self.entities_vdb, self.relationships_vdb,