From 68bf02abb6224ba1212af0ada9da54c23b0b3185 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 16:20:37 +0800 Subject: [PATCH] refactor: improve graph querying with label substring matching and security fixes --- lightrag/kg/neo4j_impl.py | 26 ++++++++++++++++---------- lightrag/kg/networkx_impl.py | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 2fb2c494..8052b1f7 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -34,7 +34,7 @@ from neo4j import ( # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") -# 从环境变量获取最大图节点数,默认为1000 +# Get maximum number of graph nodes from environment variable, default is 1000 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) @@ -473,20 +473,22 @@ class Neo4JStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: - 1. Label matching nodes take precedence + 1. Label matching nodes take precedence (nodes containing the specified label string) 2. Followed by nodes directly connected to the matching nodes 3. Finally, the degree of the nodes Args: - node_label (str): Label of the starting node + node_label (str): String to match in node labels (will match any node containing this string in its label) max_depth (int, optional): Maximum depth of the graph. Defaults to 5. Returns: KnowledgeGraph: Complete connected subgraph for specified node """ label = node_label.strip('"') + # Escape single quotes to prevent injection attacks + escaped_label = label.replace("'", "\\'") result = KnowledgeGraph() seen_nodes = set() seen_edges = set() @@ -510,16 +512,20 @@ class Neo4JStorage(BaseGraphStorage): ) else: - # Critical debug step: first verify if starting node exists - validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" + validate_query = f""" + MATCH (n) + WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}') + RETURN n LIMIT 1 + """ validate_result = await session.run(validate_query) if not await validate_result.single(): - logger.warning(f"Starting node {label} does not exist!") + logger.warning(f"No nodes containing '{label}' in their labels found!") return result - # Optimized query (including direction handling and self-loops) + # Main query uses partial matching main_query = f""" - MATCH (start:`{label}`) + MATCH (start) + WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}') WITH start CALL apoc.path.subgraphAll(start, {{ relationshipFilter: '>', @@ -598,7 +604,7 @@ class Neo4JStorage(BaseGraphStorage): result = {"nodes": [], "edges": []} visited_nodes = set() visited_edges = set() - + async def traverse(current_label: str, current_depth: int): if current_depth > max_depth: return diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 462fb832..92d36fa6 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: 1. Label matching nodes take precedence