Add node limit and prioritization for knowledge graph retrieval

• Add MAX_GRAPH_NODES limit from env var
• Prioritize nodes by label match & connection
This commit is contained in:
yangdx
2025-03-02 15:39:14 +08:00
parent 87d0ee0127
commit 0f1eb42c8d
2 changed files with 87 additions and 16 deletions

View File

@@ -23,7 +23,7 @@ import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
pm.install("neo4j") pm.install("neo4j")
from neo4j import ( from neo4j import ( # type: ignore
AsyncGraphDatabase, AsyncGraphDatabase,
exceptions as neo4jExceptions, exceptions as neo4jExceptions,
AsyncDriver, AsyncDriver,
@@ -34,6 +34,9 @@ from neo4j import (
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
# 从环境变量获取最大图节点数默认为1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final @final
@dataclass @dataclass
@@ -471,12 +474,17 @@ class Neo4JStorage(BaseGraphStorage):
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Key fixes: Args:
1. Include the starting node itself node_label (str): Label of the starting node
2. Handle multi-label nodes max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
3. Clarify relationship directions Returns:
4. Add depth control KnowledgeGraph: Complete connected subgraph for specified node
""" """
label = node_label.strip('"') label = node_label.strip('"')
result = KnowledgeGraph() result = KnowledgeGraph()
@@ -485,14 +493,22 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
try: try:
main_query = ""
if label == "*": if label == "*":
main_query = """ main_query = """
MATCH (n) MATCH (n)
WITH collect(DISTINCT n) AS nodes OPTIONAL MATCH (n)-[r]-()
MATCH ()-[r]-() WITH n, count(r) AS degree
RETURN nodes, collect(DISTINCT r) AS relationships; ORDER BY degree DESC
LIMIT $max_nodes
WITH collect(n) AS nodes
MATCH (a)-[r]->(b)
WHERE a IN nodes AND b IN nodes
RETURN nodes, collect(DISTINCT r) AS relationships
""" """
result_set = await session.run(
main_query, {"max_nodes": MAX_GRAPH_NODES}
)
else: else:
# Critical debug step: first verify if starting node exists # Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
@@ -512,9 +528,25 @@ class Neo4JStorage(BaseGraphStorage):
bfs: true bfs: true
}}) }})
YIELD nodes, relationships YIELD nodes, relationships
RETURN nodes, relationships WITH start, nodes, relationships
UNWIND nodes AS node
OPTIONAL MATCH (node)-[r]-()
WITH node, count(r) AS degree, start, nodes, relationships,
CASE
WHEN id(node) = id(start) THEN 2
WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
ELSE 0
END AS priority
ORDER BY priority DESC, degree DESC
LIMIT $max_nodes
WITH collect(node) AS filtered_nodes, nodes, relationships
RETURN filtered_nodes AS nodes,
[rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
""" """
result_set = await session.run(main_query) result_set = await session.run(
main_query, {"max_nodes": MAX_GRAPH_NODES}
)
record = await result_set.single() record = await result_set.single()
if record: if record:

View File

@@ -236,7 +236,11 @@ class NetworkXStorage(BaseGraphStorage):
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args: Args:
node_label: Label of the starting node node_label: Label of the starting node
@@ -268,14 +272,49 @@ class NetworkXStorage(BaseGraphStorage):
logger.warning(f"No nodes found with label {node_label}") logger.warning(f"No nodes found with label {node_label}")
return result return result
# Get subgraph using ego_graph # Get subgraph using ego_graph from all matching nodes
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) combined_subgraph = nx.Graph()
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
subgraph = combined_subgraph
# Check if number of nodes exceeds max_graph_nodes # Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES: if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes()) origin_nodes = len(subgraph.nodes())
# 获取节点度数
node_degrees = dict(subgraph.degree()) node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
# 标记起点节点和直接连接的节点
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
# 所有在 nodes_to_explore 中的节点都是起点节点
start_nodes = set(nodes_to_explore)
# 获取与所有起点直接连接的节点
for start_node in start_nodes:
direct_connected_nodes.update(subgraph.neighbors(start_node))
# 从直接连接节点中移除起点节点(避免重复)
direct_connected_nodes -= start_nodes
# 按优先级和度数排序
def priority_key(node_item):
node, degree = node_item
# 优先级排序:起点(2) > 直接连接(1) > 其他节点(0)
if node in start_nodes:
priority = 2
elif node in direct_connected_nodes:
priority = 1
else:
priority = 0
return (priority, degree) # 先按优先级,再按度数
# 排序并选择前MAX_GRAPH_NODES个节点
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
:MAX_GRAPH_NODES :MAX_GRAPH_NODES
] ]
top_node_ids = [node[0] for node in top_nodes] top_node_ids = [node[0] for node in top_nodes]