diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index 41046562..381df90b 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -2,61 +2,32 @@ This module contains all graph-related routes for the LightRAG API. """ -from typing import Optional, List, Dict, Any +from typing import Optional from fastapi import APIRouter, Depends, Query -from pydantic import BaseModel, Field from ..utils_api import get_combined_auth_dependency router = APIRouter(tags=["graph"]) -# Pydantic models for graph routes -class GraphLabelsResponse(BaseModel): - """Response model: List of graph labels""" - labels: List[str] = Field(description="List of graph labels") - -class KnowledgeGraphNode(BaseModel): - """Model for a node in the knowledge graph""" - id: str = Field(description="Unique identifier of the node") - label: str = Field(description="Label of the node") - properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the node") - -class KnowledgeGraphEdge(BaseModel): - """Model for an edge in the knowledge graph""" - source: str = Field(description="Source node ID") - target: str = Field(description="Target node ID") - type: str = Field(description="Type of the relationship") - properties: Dict[str, Any] = Field(default_factory=dict, description="Properties of the edge") - -class KnowledgeGraphResponse(BaseModel): - """Response model: Knowledge graph data""" - nodes: List[KnowledgeGraphNode] = Field(description="List of nodes in the graph") - edges: List[KnowledgeGraphEdge] = Field(description="List of edges in the graph") - def create_graph_routes(rag, api_key: Optional[str] = None): combined_auth = get_combined_auth_dependency(api_key) - @router.get("/graph/label/list", - dependencies=[Depends(combined_auth)], - response_model=GraphLabelsResponse) + @router.get("/graph/label/list", dependencies=[Depends(combined_auth)]) async def get_graph_labels(): """ Get all graph labels Returns: - GraphLabelsResponse: List of graph labels + List[str]: List of graph labels """ - labels = await rag.get_graph_labels() - return GraphLabelsResponse(labels=labels) + return await rag.get_graph_labels() - @router.get("/graphs", - dependencies=[Depends(combined_auth)], - response_model=KnowledgeGraphResponse) + @router.get("/graphs", dependencies=[Depends(combined_auth)]) async def get_knowledge_graph( label: str = Query(..., description="Label to get knowledge graph for"), max_depth: int = Query(3, description="Maximum depth of graph", ge=1), - max_nodes: int = Query(1000, description="Maxiumu nodes to return", ge=1), + max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), ): """ Retrieve a connected subgraph of nodes where the label includes the specified label. @@ -70,19 +41,12 @@ def create_graph_routes(rag, api_key: Optional[str] = None): max_nodes: Maxiumu nodes to return Returns: - KnowledgeGraphResponse: Knowledge graph containing nodes and edges + Dict[str, List[str]]: Knowledge graph for label """ - graph_data = await rag.get_knowledge_graph( + return await rag.get_knowledge_graph( node_label=label, max_depth=max_depth, max_nodes=max_nodes, ) - - # Convert the returned dictionary to our response model format - # Assuming the returned dictionary has 'nodes' and 'edges' keys - return KnowledgeGraphResponse( - nodes=graph_data.get("nodes", []), - edges=graph_data.get("edges", []) - ) return router diff --git a/lightrag/base.py b/lightrag/base.py index 223cc7c9..ec7ba9fa 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -341,7 +341,7 @@ class BaseGraphStorage(StorageNameSpace, ABC): @abstractmethod async def get_knowledge_graph( - self, node_label: str, max_depth: int = 3 + self, node_label: str, max_depth: int = 3, max_nodes: int = 1000 ) -> KnowledgeGraph: """Retrieve a subgraph of the knowledge graph starting from a given node.""" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7077f94d..201bd6bf 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -510,36 +510,20 @@ class LightRAG: self, node_label: str, max_depth: int = 3, - min_degree: int = 0, - inclusive: bool = False, + max_nodes: int = 1000, ) -> KnowledgeGraph: """Get knowledge graph for a given label Args: node_label (str): Label to get knowledge graph for max_depth (int): Maximum depth of graph - min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0. - inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False. + max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000. Returns: KnowledgeGraph: Knowledge graph containing nodes and edges """ - # get params supported by get_knowledge_graph of specified storage - import inspect - storage_params = inspect.signature( - self.chunk_entity_relation_graph.get_knowledge_graph - ).parameters - - kwargs = {"node_label": node_label, "max_depth": max_depth} - - if "min_degree" in storage_params and min_degree > 0: - kwargs["min_degree"] = min_degree - - if "inclusive" in storage_params: - kwargs["inclusive"] = inclusive - - return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs) + return await self.chunk_entity_relation_graph.get_knowledge_graph(node_label, max_depth, max_nodes) def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: import_path = STORAGES[storage_name]