Add max nodes limit for graph retrieval of networkX

• Set MAX_GRAPH_NODES env var (default 1000)
• Change edge type to "RELATED"
This commit is contained in:
yangdx
2025-03-02 12:52:25 +08:00
parent 7124845e55
commit 1ca6837219
3 changed files with 26 additions and 8 deletions

View File

@@ -3,6 +3,7 @@
# PORT=9621 # PORT=9621
# WORKERS=1 # WORKERS=1
# NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances # NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
# MAX_GRAPH_NODES=1000 # Max nodes return from grap retrieval
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
### Optional SSL Configuration ### Optional SSL Configuration

View File

@@ -16,12 +16,27 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)]) @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
async def get_graph_labels(): async def get_graph_labels():
"""Get all graph labels""" """
Get all graph labels
Returns:
List[str]: List of graph labels
"""
return await rag.get_graph_labels() return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)]) @router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(label: str, max_depth: int = 3): async def get_knowledge_graph(label: str, max_depth: int = 3):
"""Get knowledge graph for a specific label""" """
Get knowledge graph for a specific label.
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
Returns:
Dict[str, List[str]]: Knowledge graph for label
"""
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth) return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
return router return router

View File

@@ -24,6 +24,8 @@ from .shared_storage import (
is_multiprocess, is_multiprocess,
) )
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final @final
@dataclass @dataclass
@@ -234,6 +236,7 @@ class NetworkXStorage(BaseGraphStorage):
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args: Args:
node_label: Label of the starting node node_label: Label of the starting node
@@ -269,18 +272,17 @@ class NetworkXStorage(BaseGraphStorage):
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
# Check if number of nodes exceeds max_graph_nodes # Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500 if len(subgraph.nodes()) > MAX_GRAPH_NODES:
if len(subgraph.nodes()) > max_graph_nodes:
origin_nodes = len(subgraph.nodes()) origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree()) node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
:max_graph_nodes :MAX_GRAPH_NODES
] ]
top_node_ids = [node[0] for node in top_nodes] top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph with only top nodes # Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids) subgraph = subgraph.subgraph(top_node_ids)
logger.info( logger.info(
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
) )
# Add nodes to result # Add nodes to result
@@ -320,7 +322,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append( result.edges.append(
KnowledgeGraphEdge( KnowledgeGraphEdge(
id=edge_id, id=edge_id,
type="DIRECTED", type="RELATED",
source=str(source), source=str(source),
target=str(target), target=str(target),
properties=edge_data, properties=edge_data,