Update graph retrival api

This commit is contained in:
yangdx
2025-04-02 17:21:45 +08:00
parent fb2fad7766
commit d62a77500b

View File

@@ -2,53 +2,87 @@
This module contains all graph-related routes for the LightRAG API.
"""
from typing import Optional
from fastapi import APIRouter, Depends
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel, Field
from ..utils_api import get_combined_auth_dependency
router = APIRouter(tags=["graph"])
# Pydantic models for graph routes
class GraphLabelsResponse(BaseModel):
"""Response model: List of graph labels"""
labels: List[str] = Field(description="List of graph labels")
class KnowledgeGraphNode(BaseModel):
"""Model for a node in the knowledge graph"""
id: str = Field(description="Unique identifier of the node")
label: str = Field(description="Label of the node")
properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the node")
class KnowledgeGraphEdge(BaseModel):
"""Model for an edge in the knowledge graph"""
source: str = Field(description="Source node ID")
target: str = Field(description="Target node ID")
type: str = Field(description="Type of the relationship")
properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the edge")
class KnowledgeGraphResponse(BaseModel):
"""Response model: Knowledge graph data"""
nodes: List[KnowledgeGraphNode] = Field(description="List of nodes in the graph")
edges: List[KnowledgeGraphEdge] = Field(description="List of edges in the graph")
def create_graph_routes(rag, api_key: Optional[str] = None):
combined_auth = get_combined_auth_dependency(api_key)
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
@router.get("/graph/label/list",
dependencies=[Depends(combined_auth)],
response_model=GraphLabelsResponse)
async def get_graph_labels():
"""
Get all graph labels
Returns:
List[str]: List of graph labels
GraphLabelsResponse: List of graph labels
"""
return await rag.get_graph_labels()
labels = await rag.get_graph_labels()
return GraphLabelsResponse(labels=labels)
@router.get("/graphs", dependencies=[Depends(combined_auth)])
@router.get("/graphs",
dependencies=[Depends(combined_auth)],
response_model=KnowledgeGraphResponse)
async def get_knowledge_graph(
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
label: str = Query(..., description="Label to get knowledge graph for"),
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
max_nodes: int = Query(1000, description="Maxiumu nodes to return", ge=1),
):
"""
Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
1. Hops(path) to the staring node take precedence
2. Followed by the degree of the nodes
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
min_degree (int, optional): Minimum degree of nodes. Defaults to 0. Deprecated, always 0
label (str): Label of the starting node
max_depth (int, optional): Maximum depth of the subgraph,Defaults to 3
max_nodes: Maxiumu nodes to return
Returns:
Dict[str, List[str]]: Knowledge graph for label
KnowledgeGraphResponse: Knowledge graph containing nodes and edges
"""
return await rag.get_knowledge_graph(
graph_data = await rag.get_knowledge_graph(
node_label=label,
max_depth=max_depth,
max_nodes=max_nodes,
)
# Convert the returned dictionary to our response model format
# Assuming the returned dictionary has 'nodes' and 'edges' keys
return KnowledgeGraphResponse(
nodes=graph_data.get("nodes", []),
edges=graph_data.get("edges", [])
)
return router