Add max nodes limit for graph retrieval of networkX

• Set MAX_GRAPH_NODES env var (default 1000)
• Change edge type to "RELATED"
This commit is contained in:
yangdx
2025-03-02 12:52:25 +08:00
parent 7124845e55
commit 1ca6837219
3 changed files with 26 additions and 8 deletions

View File

@@ -3,6 +3,7 @@
# PORT=9621
# WORKERS=1
# NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
# MAX_GRAPH_NODES=1000 # Max nodes return from grap retrieval
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
### Optional SSL Configuration

View File

@@ -16,12 +16,27 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
async def get_graph_labels():
"""Get all graph labels"""
"""
Get all graph labels
Returns:
List[str]: List of graph labels
"""
return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(label: str, max_depth: int = 3):
"""Get knowledge graph for a specific label"""
"""
Get knowledge graph for a specific label.
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
Returns:
Dict[str, List[str]]: Knowledge graph for label
"""
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
return router

View File

@@ -24,6 +24,8 @@ from .shared_storage import (
is_multiprocess,
)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final
@dataclass
@@ -234,6 +236,7 @@ class NetworkXStorage(BaseGraphStorage):
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args:
node_label: Label of the starting node
@@ -269,18 +272,17 @@ class NetworkXStorage(BaseGraphStorage):
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
# Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500
if len(subgraph.nodes()) > max_graph_nodes:
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
:max_graph_nodes
:MAX_GRAPH_NODES
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph with only top nodes
# Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
)
# Add nodes to result
@@ -320,7 +322,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
type="RELATED",
source=str(source),
target=str(target),
properties=edge_data,