diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 3e77cc59..a531d7f0 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1424,8 +1424,8 @@ def create_app(args): # query all graph @app.get("/graphs") - async def get_graphs(label: str): - return await rag.get_graphs(nodel_label=label, max_depth=100) + async def get_knowledge_graph(label: str): + return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) # Add Ollama API routes ollama_api = OllamaAPI(rag) diff --git a/lightrag/base.py b/lightrag/base.py index 3702b49e..e75167c4 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -13,6 +13,7 @@ from typing import ( import numpy as np from .utils import EmbeddingFunc +from .types import KnowledgeGraph class TextChunkSchema(TypedDict): @@ -175,7 +176,7 @@ class BaseGraphStorage(StorageNameSpace): async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 - ) -> dict[str, list[dict]]: + ) -> KnowledgeGraph: raise NotImplementedError diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 79365c87..7845780d 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -25,6 +25,7 @@ from tenacity import ( from ..utils import logger from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge @dataclass @@ -44,7 +45,8 @@ class Neo4JStorage(BaseGraphStorage): URI = os.environ["NEO4J_URI"] USERNAME = os.environ["NEO4J_USERNAME"] PASSWORD = os.environ["NEO4J_PASSWORD"] - MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) + MAX_CONNECTION_POOL_SIZE = os.environ.get( + "NEO4J_MAX_CONNECTION_POOL_SIZE", 800) DATABASE = os.environ.get( "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) ) @@ -74,19 +76,22 @@ class Neo4JStorage(BaseGraphStorage): ) raise e except neo4jExceptions.AuthError as e: - logger.error(f"Authentication failed for {database} at {URI}") + logger.error( + f"Authentication failed for {database} at {URI}") raise e except neo4jExceptions.ClientError as e: if e.code == "Neo.ClientError.Database.DatabaseNotFound": logger.info( - f"{database} at {URI} not found. Try to create specified database.".capitalize() + f"{database} at {URI} not found. Try to create specified database.".capitalize( + ) ) try: with _sync_driver.session() as session: session.run( f"CREATE DATABASE `{database}` IF NOT EXISTS" ) - logger.info(f"{database} at {URI} created".capitalize()) + logger.info( + f"{database} at {URI} created".capitalize()) connected = True except ( neo4jExceptions.ClientError, @@ -103,7 +108,8 @@ class Neo4JStorage(BaseGraphStorage): "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." ) if database is None: - logger.error(f"Failed to create {database} at {URI}") + logger.error( + f"Failed to create {database} at {URI}") raise e if connected: @@ -365,7 +371,7 @@ class Neo4JStorage(BaseGraphStorage): async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 - ) -> Dict[str, List[Dict]]: + ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -376,7 +382,7 @@ class Neo4JStorage(BaseGraphStorage): 4. Add depth control """ label = node_label.strip('"') - result = {"nodes": [], "edges": []} + result = KnowledgeGraph() seen_nodes = set() seen_edges = set() @@ -395,7 +401,8 @@ class Neo4JStorage(BaseGraphStorage): validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" validate_result = await session.run(validate_query) if not await validate_result.single(): - logger.warning(f"Starting node {label} does not exist!") + logger.warning( + f"Starting node {label} does not exist!") return result # Optimized query (including direction handling and self-loops) @@ -420,11 +427,11 @@ class Neo4JStorage(BaseGraphStorage): # Use node ID + label combination as unique identifier node_id = node.id if node_id not in seen_nodes: - node_data = {} - node_data["labels"] = list(node.labels) # Keep all labels - node_data["id"] = f"{node_id}" - node_data["properties"] = dict(node) - result["nodes"].append(node_data) + result.nodes.append(KnowledgeGraphNode( + id=f"{node_id}", + labels=list(node.labels), + properties=dict(node), + )) seen_nodes.add(node_id) # Handle relationships (including direction information) @@ -433,21 +440,17 @@ class Neo4JStorage(BaseGraphStorage): if edge_id not in seen_edges: start = rel.start_node end = rel.end_node - edge_data = {} - edge_data.update( - { - "source": f"{start.id}", - "target": f"{end.id}", - "type": rel.type, - "id": f"{edge_id}", - "properties": dict(rel), - } - ) - result["edges"].append(edge_data) + result.edges.append(KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + )) seen_edges.add(edge_id) logger.info( - f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}" + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) except neo4jExceptions.ClientError as e: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 157c6ef2..a6656c0b 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -34,6 +34,7 @@ from .utils import ( logger, set_logger, ) +from .types import KnowledgeGraph STORAGES = { "NetworkXStorage": ".kg.networkx_impl", @@ -385,7 +386,7 @@ class LightRAG: text = await self.chunk_entity_relation_graph.get_all_labels() return text - async def get_graphs(self, nodel_label: str, max_depth: int): + async def get_knowledge_graph(self, nodel_label: str, max_depth: int) -> KnowledgeGraph: return await self.chunk_entity_relation_graph.get_knowledge_graph( node_label=nodel_label, max_depth=max_depth ) diff --git a/lightrag/types.py b/lightrag/types.py index 8190b2ea..9c8e0099 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -1,7 +1,26 @@ from pydantic import BaseModel -from typing import List +from typing import List, Dict, Any class GPTKeywordExtractionFormat(BaseModel): high_level_keywords: List[str] low_level_keywords: List[str] + + +class KnowledgeGraphNode(BaseModel): + id: str + labels: List[str] + properties: Dict[str, Any] # anything else goes here + + +class KnowledgeGraphEdge(BaseModel): + id: str + type: str + source: str # id of source node + target: str # id of target node + properties: Dict[str, Any] # anything else goes here + + +class KnowledgeGraph(BaseModel): + nodes: List[KnowledgeGraphNode] = [] + edges: List[KnowledgeGraphEdge] = []