refactor: improve graph querying with label substring matching and security fixes

This commit is contained in:
yangdx
2025-03-02 16:20:37 +08:00
parent 0f1eb42c8d
commit 68bf02abb6
2 changed files with 17 additions and 11 deletions

View File

@@ -34,7 +34,7 @@ from neo4j import ( # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# 从环境变量获取最大图节点数,默认为1000
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@@ -473,20 +473,22 @@ class Neo4JStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
1. Label matching nodes take precedence (nodes containing the specified label string)
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args:
node_label (str): Label of the starting node
node_label (str): String to match in node labels (will match any node containing this string in its label)
max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
Returns:
KnowledgeGraph: Complete connected subgraph for specified node
"""
label = node_label.strip('"')
# Escape single quotes to prevent injection attacks
escaped_label = label.replace("'", "\\'")
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
@@ -510,16 +512,20 @@ class Neo4JStorage(BaseGraphStorage):
)
else:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_query = f"""
MATCH (n)
WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
RETURN n LIMIT 1
"""
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
logger.warning(f"No nodes containing '{label}' in their labels found!")
return result
# Optimized query (including direction handling and self-loops)
# Main query uses partial matching
main_query = f"""
MATCH (start:`{label}`)
MATCH (start)
WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
WITH start
CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>',

View File

@@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence