Update graph retrival api(abandon pydantic model)
This commit is contained in:
@@ -2,61 +2,32 @@
|
|||||||
This module contains all graph-related routes for the LightRAG API.
|
This module contains all graph-related routes for the LightRAG API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from ..utils_api import get_combined_auth_dependency
|
from ..utils_api import get_combined_auth_dependency
|
||||||
|
|
||||||
router = APIRouter(tags=["graph"])
|
router = APIRouter(tags=["graph"])
|
||||||
|
|
||||||
# Pydantic models for graph routes
|
|
||||||
class GraphLabelsResponse(BaseModel):
|
|
||||||
"""Response model: List of graph labels"""
|
|
||||||
labels: List[str] = Field(description="List of graph labels")
|
|
||||||
|
|
||||||
class KnowledgeGraphNode(BaseModel):
|
|
||||||
"""Model for a node in the knowledge graph"""
|
|
||||||
id: str = Field(description="Unique identifier of the node")
|
|
||||||
label: str = Field(description="Label of the node")
|
|
||||||
properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the node")
|
|
||||||
|
|
||||||
class KnowledgeGraphEdge(BaseModel):
|
|
||||||
"""Model for an edge in the knowledge graph"""
|
|
||||||
source: str = Field(description="Source node ID")
|
|
||||||
target: str = Field(description="Target node ID")
|
|
||||||
type: str = Field(description="Type of the relationship")
|
|
||||||
properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the edge")
|
|
||||||
|
|
||||||
class KnowledgeGraphResponse(BaseModel):
|
|
||||||
"""Response model: Knowledge graph data"""
|
|
||||||
nodes: List[KnowledgeGraphNode] = Field(description="List of nodes in the graph")
|
|
||||||
edges: List[KnowledgeGraphEdge] = Field(description="List of edges in the graph")
|
|
||||||
|
|
||||||
|
|
||||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||||
combined_auth = get_combined_auth_dependency(api_key)
|
combined_auth = get_combined_auth_dependency(api_key)
|
||||||
|
|
||||||
@router.get("/graph/label/list",
|
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
|
||||||
dependencies=[Depends(combined_auth)],
|
|
||||||
response_model=GraphLabelsResponse)
|
|
||||||
async def get_graph_labels():
|
async def get_graph_labels():
|
||||||
"""
|
"""
|
||||||
Get all graph labels
|
Get all graph labels
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
GraphLabelsResponse: List of graph labels
|
List[str]: List of graph labels
|
||||||
"""
|
"""
|
||||||
labels = await rag.get_graph_labels()
|
return await rag.get_graph_labels()
|
||||||
return GraphLabelsResponse(labels=labels)
|
|
||||||
|
|
||||||
@router.get("/graphs",
|
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
||||||
dependencies=[Depends(combined_auth)],
|
|
||||||
response_model=KnowledgeGraphResponse)
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
label: str = Query(..., description="Label to get knowledge graph for"),
|
label: str = Query(..., description="Label to get knowledge graph for"),
|
||||||
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
|
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
|
||||||
max_nodes: int = Query(1000, description="Maxiumu nodes to return", ge=1),
|
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||||
@@ -70,19 +41,12 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||||||
max_nodes: Maxiumu nodes to return
|
max_nodes: Maxiumu nodes to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeGraphResponse: Knowledge graph containing nodes and edges
|
Dict[str, List[str]]: Knowledge graph for label
|
||||||
"""
|
"""
|
||||||
graph_data = await rag.get_knowledge_graph(
|
return await rag.get_knowledge_graph(
|
||||||
node_label=label,
|
node_label=label,
|
||||||
max_depth=max_depth,
|
max_depth=max_depth,
|
||||||
max_nodes=max_nodes,
|
max_nodes=max_nodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert the returned dictionary to our response model format
|
|
||||||
# Assuming the returned dictionary has 'nodes' and 'edges' keys
|
|
||||||
return KnowledgeGraphResponse(
|
|
||||||
nodes=graph_data.get("nodes", []),
|
|
||||||
edges=graph_data.get("edges", [])
|
|
||||||
)
|
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
@@ -341,7 +341,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int = 3
|
self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
|
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
|
||||||
|
|
||||||
|
@@ -510,36 +510,20 @@ class LightRAG:
|
|||||||
self,
|
self,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
max_depth: int = 3,
|
max_depth: int = 3,
|
||||||
min_degree: int = 0,
|
max_nodes: int = 1000,
|
||||||
inclusive: bool = False,
|
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
"""Get knowledge graph for a given label
|
"""Get knowledge graph for a given label
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_label (str): Label to get knowledge graph for
|
node_label (str): Label to get knowledge graph for
|
||||||
max_depth (int): Maximum depth of graph
|
max_depth (int): Maximum depth of graph
|
||||||
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
|
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
|
||||||
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeGraph: Knowledge graph containing nodes and edges
|
KnowledgeGraph: Knowledge graph containing nodes and edges
|
||||||
"""
|
"""
|
||||||
# get params supported by get_knowledge_graph of specified storage
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
storage_params = inspect.signature(
|
return await self.chunk_entity_relation_graph.get_knowledge_graph(node_label, max_depth, max_nodes)
|
||||||
self.chunk_entity_relation_graph.get_knowledge_graph
|
|
||||||
).parameters
|
|
||||||
|
|
||||||
kwargs = {"node_label": node_label, "max_depth": max_depth}
|
|
||||||
|
|
||||||
if "min_degree" in storage_params and min_degree > 0:
|
|
||||||
kwargs["min_degree"] = min_degree
|
|
||||||
|
|
||||||
if "inclusive" in storage_params:
|
|
||||||
kwargs["inclusive"] = inclusive
|
|
||||||
|
|
||||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
|
|
||||||
|
|
||||||
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
||||||
import_path = STORAGES[storage_name]
|
import_path = STORAGES[storage_name]
|
||||||
|
Reference in New Issue
Block a user