Add is_truncated to graph query for NetworkX graph db

This commit is contained in:
yangdx
2025-04-02 22:12:20 +08:00
parent 4ceafb7cbc
commit 82c4baba70
4 changed files with 33 additions and 12 deletions

View File

@@ -343,7 +343,18 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 3, max_nodes: int = 1000 self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
) -> KnowledgeGraph: ) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node.""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args:
node_label: Label of the starting node* means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
class DocStatus(str, Enum): class DocStatus(str, Enum):

View File

@@ -651,17 +651,14 @@ class Neo4JStorage(BaseGraphStorage):
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
When reducing the number of nodes, the prioritization criteria are as follows:
1. Hops(path) to the staring node take precedence
2. Followed by the degree of the nodes
Args: Args:
node_label: Label of the starting node node_label: Label of the starting node* means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3 max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
Returns: Returns:
KnowledgeGraph: Complete connected subgraph for specified node KnowledgeGraph object containing nodes and edges
""" """
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()

View File

@@ -270,16 +270,24 @@ class NetworkXStorage(BaseGraphStorage):
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
Returns: Returns:
KnowledgeGraph object containing nodes and edges KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
""" """
graph = await self._get_graph() graph = await self._get_graph()
result = KnowledgeGraph()
# Handle special case for "*" label # Handle special case for "*" label
if node_label == "*": if node_label == "*":
# Get degrees of all nodes # Get degrees of all nodes
degrees = dict(graph.degree()) degrees = dict(graph.degree())
# Sort nodes by degree in descending order and take top max_nodes # Sort nodes by degree in descending order and take top max_nodes
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True) sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
# Check if graph is truncated
if len(sorted_nodes) > max_nodes:
result.is_truncated = True
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]] limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
# Create subgraph with the highest degree nodes # Create subgraph with the highest degree nodes
subgraph = graph.subgraph(limited_nodes) subgraph = graph.subgraph(limited_nodes)
@@ -305,11 +313,15 @@ class NetworkXStorage(BaseGraphStorage):
neighbors = list(graph.neighbors(current)) neighbors = list(graph.neighbors(current))
queue.extend([n for n in neighbors if n not in visited]) queue.extend([n for n in neighbors if n not in visited])
# Check if graph is truncated - if we still have nodes in the queue
# and we've reached max_nodes, then the graph is truncated
if queue and len(bfs_nodes) >= max_nodes:
result.is_truncated = True
# Create subgraph with BFS discovered nodes # Create subgraph with BFS discovered nodes
subgraph = graph.subgraph(bfs_nodes) subgraph = graph.subgraph(bfs_nodes)
# Add nodes to result # Add nodes to result
result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
for node in subgraph.nodes(): for node in subgraph.nodes():

View File

@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
class KnowledgeGraph(BaseModel): class KnowledgeGraph(BaseModel):
nodes: list[KnowledgeGraphNode] = [] nodes: list[KnowledgeGraphNode] = []
edges: list[KnowledgeGraphEdge] = [] edges: list[KnowledgeGraphEdge] = []
is_truncated: bool = False