Update graph retrival api
This commit is contained in:
@@ -2,53 +2,87 @@
|
||||
This module contains all graph-related routes for the LightRAG API.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..utils_api import get_combined_auth_dependency
|
||||
|
||||
router = APIRouter(tags=["graph"])
|
||||
|
||||
# Pydantic models for graph routes
|
||||
class GraphLabelsResponse(BaseModel):
|
||||
"""Response model: List of graph labels"""
|
||||
labels: List[str] = Field(description="List of graph labels")
|
||||
|
||||
class KnowledgeGraphNode(BaseModel):
|
||||
"""Model for a node in the knowledge graph"""
|
||||
id: str = Field(description="Unique identifier of the node")
|
||||
label: str = Field(description="Label of the node")
|
||||
properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the node")
|
||||
|
||||
class KnowledgeGraphEdge(BaseModel):
|
||||
"""Model for an edge in the knowledge graph"""
|
||||
source: str = Field(description="Source node ID")
|
||||
target: str = Field(description="Target node ID")
|
||||
type: str = Field(description="Type of the relationship")
|
||||
properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the edge")
|
||||
|
||||
class KnowledgeGraphResponse(BaseModel):
|
||||
"""Response model: Knowledge graph data"""
|
||||
nodes: List[KnowledgeGraphNode] = Field(description="List of nodes in the graph")
|
||||
edges: List[KnowledgeGraphEdge] = Field(description="List of edges in the graph")
|
||||
|
||||
|
||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
|
||||
@router.get("/graph/label/list",
|
||||
dependencies=[Depends(combined_auth)],
|
||||
response_model=GraphLabelsResponse)
|
||||
async def get_graph_labels():
|
||||
"""
|
||||
Get all graph labels
|
||||
|
||||
Returns:
|
||||
List[str]: List of graph labels
|
||||
GraphLabelsResponse: List of graph labels
|
||||
"""
|
||||
return await rag.get_graph_labels()
|
||||
labels = await rag.get_graph_labels()
|
||||
return GraphLabelsResponse(labels=labels)
|
||||
|
||||
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
||||
@router.get("/graphs",
|
||||
dependencies=[Depends(combined_auth)],
|
||||
response_model=KnowledgeGraphResponse)
|
||||
async def get_knowledge_graph(
|
||||
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
||||
label: str = Query(..., description="Label to get knowledge graph for"),
|
||||
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
|
||||
max_nodes: int = Query(1000, description="Maxiumu nodes to return", ge=1),
|
||||
):
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
|
||||
1. Hops(path) to the staring node take precedence
|
||||
2. Followed by the degree of the nodes
|
||||
|
||||
Args:
|
||||
label (str): Label to get knowledge graph for
|
||||
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
|
||||
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
|
||||
min_degree (int, optional): Minimum degree of nodes. Defaults to 0. (Deprecated, always 0)
|
||||
label (str): Label of the starting node
|
||||
max_depth (int, optional): Maximum depth of the subgraph,Defaults to 3
|
||||
max_nodes: Maxiumu nodes to return
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Knowledge graph for label
|
||||
KnowledgeGraphResponse: Knowledge graph containing nodes and edges
|
||||
"""
|
||||
return await rag.get_knowledge_graph(
|
||||
graph_data = await rag.get_knowledge_graph(
|
||||
node_label=label,
|
||||
max_depth=max_depth,
|
||||
max_nodes=max_nodes,
|
||||
)
|
||||
|
||||
# Convert the returned dictionary to our response model format
|
||||
# Assuming the returned dictionary has 'nodes' and 'edges' keys
|
||||
return KnowledgeGraphResponse(
|
||||
nodes=graph_data.get("nodes", []),
|
||||
edges=graph_data.get("edges", [])
|
||||
)
|
||||
|
||||
return router
|
||||
|
Reference in New Issue
Block a user