diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 58931eec..96315b82 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1683,6 +1683,10 @@ def create_app(args): raise HTTPException(status_code=500, detail=str(e)) # query all graph + @app.get("/graphs") + async def get_knowledge_graph(label: str): + return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + # Add Ollama API routes ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") diff --git a/lightrag/base.py b/lightrag/base.py index 5f6f8850..af060435 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -13,6 +13,7 @@ from typing import ( ) import numpy as np from .utils import EmbeddingFunc +from .types import KnowledgeGraph load_dotenv() @@ -197,6 +198,12 @@ class BaseGraphStorage(StorageNameSpace, ABC): ) -> tuple[np.ndarray[Any, Any], list[str]]: """Get all labels in the graph.""" + @abstractmethod + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """Retrieve a subgraph of the knowledge graph starting from a given node.""" + class DocStatus(str, Enum): """Document processing status""" diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 583423bb..077c7321 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Union, final import numpy as np import pipmaster as pm +from lightrag.types import KnowledgeGraph from tenacity import ( retry, @@ -615,6 +616,11 @@ class AGEStorage(BaseGraphStorage): ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + async def index_done_callback(self) -> None: # AGES handles persistence automatically pass diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 45bc1fab..39077b5f 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -16,6 +16,7 @@ from tenacity import ( wait_exponential, ) +from lightrag.types import KnowledgeGraph from lightrag.utils import logger from ..base import BaseGraphStorage @@ -401,3 +402,8 @@ class GremlinStorage(BaseGraphStorage): self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index cfae4abd..07b48f8b 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -16,6 +16,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge import pipmaster as pm if not pm.is_installed("pymongo"): @@ -598,6 +599,179 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # QUERY # ------------------------------------------------------------------------- + # + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """ + Get complete connected subgraph for specified node (including the starting node itself) + + Args: + node_label: Label of the nodes to start from + max_depth: Maximum depth of traversal (default: 5) + + Returns: + KnowledgeGraph object containing nodes and edges of the subgraph + """ + label = node_label + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + try: + if label == "*": + # Get all nodes and edges + async for node_doc in self.collection.find({}): + node_id = str(node_doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_doc.get("_id")], + properties={ + k: v + for k, v in node_doc.items() + if k not in ["_id", "edges"] + }, + ) + ) + seen_nodes.add(node_id) + + # Process edges + for edge in node_doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + else: + # Verify if starting node exists + start_nodes = self.collection.find({"_id": label}) + start_nodes_exist = await start_nodes.to_list(length=1) + if not start_nodes_exist: + logger.warning(f"Starting node with label {label} does not exist!") + return result + + # Use $graphLookup for traversal + pipeline = [ + { + "$match": {"_id": label} + }, # Start with nodes having the specified label + { + "$graphLookup": { + "from": self._collection_name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_nodes", + } + }, + ] + + async for doc in self.collection.aggregate(pipeline): + # Add the start node + node_id = str(doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[ + doc.get( + "_id", + ) + ], + properties={ + k: v + for k, v in doc.items() + if k + not in [ + "_id", + "edges", + "connected_nodes", + "depth", + ] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from start node + for edge in doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + # Add connected nodes and their edges + for connected in doc.get("connected_nodes", []): + node_id = str(connected["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[connected.get("_id")], + properties={ + k: v + for k, v in connected.items() + if k not in ["_id", "edges", "depth"] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from connected nodes + for edge in connected.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except PyMongoError as e: + logger.error(f"MongoDB query failed: {str(e)}") + + return result async def index_done_callback(self) -> None: # Mongo handles persistence automatically diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 9754ffc5..de0273ad 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -17,6 +17,7 @@ from tenacity import ( from ..utils import logger from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge import pipmaster as pm if not pm.is_installed("neo4j"): @@ -468,6 +469,99 @@ class Neo4JStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """ + Get complete connected subgraph for specified node (including the starting node itself) + + Key fixes: + 1. Include the starting node itself + 2. Handle multi-label nodes + 3. Clarify relationship directions + 4. Add depth control + """ + label = node_label.strip('"') + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + async with self._driver.session(database=self._DATABASE) as session: + try: + main_query = "" + if label == "*": + main_query = """ + MATCH (n) + WITH collect(DISTINCT n) AS nodes + MATCH ()-[r]-() + RETURN nodes, collect(DISTINCT r) AS relationships; + """ + else: + # Critical debug step: first verify if starting node exists + validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" + validate_result = await session.run(validate_query) + if not await validate_result.single(): + logger.warning(f"Starting node {label} does not exist!") + return result + + # Optimized query (including direction handling and self-loops) + main_query = f""" + MATCH (start:`{label}`) + WITH start + CALL apoc.path.subgraphAll(start, {{ + relationshipFilter: '>', + minLevel: 0, + maxLevel: {max_depth}, + bfs: true + }}) + YIELD nodes, relationships + RETURN nodes, relationships + """ + result_set = await session.run(main_query) + record = await result_set.single() + + if record: + # Handle nodes (compatible with multi-label cases) + for node in record["nodes"]: + # Use node ID + label combination as unique identifier + node_id = node.id + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=list(node.labels), + properties=dict(node), + ) + ) + seen_nodes.add(node_id) + + # Handle relationships (including direction information) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except neo4jExceptions.ClientError as e: + logger.error(f"APOC query failed: {str(e)}") + return await self._robust_fallback(label, max_depth) + + return result + async def _robust_fallback( self, label: str, max_depth: int ) -> Dict[str, List[Dict]]: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 1874719f..3e7a08fd 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -5,6 +5,7 @@ from typing import Any, final import numpy as np +from lightrag.types import KnowledgeGraph from lightrag.utils import ( logger, ) @@ -166,3 +167,8 @@ class NetworkXStorage(BaseGraphStorage): for source, target in edges: if self._graph.has_edge(source, target): self._graph.remove_edge(source, target) + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 35983ad3..d65688da 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -8,6 +8,7 @@ from typing import Any, Union, final import numpy as np import configparser +from lightrag.types import KnowledgeGraph from ..base import ( BaseGraphStorage, @@ -669,6 +670,11 @@ class OracleGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index d7ace41a..a0e0f184 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -7,6 +7,7 @@ from typing import Any, Union, final import numpy as np import configparser +from lightrag.types import KnowledgeGraph import sys from tenacity import ( @@ -1084,6 +1085,11 @@ class PGGraphStorage(BaseGraphStorage): ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + async def drop(self) -> None: """Drop the storage""" drop_sql = SQL_TEMPLATES["drop_vdb_entity"] diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 2feb782a..7ba2cf66 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -5,6 +5,8 @@ from typing import Any, Union, final import numpy as np +from lightrag.types import KnowledgeGraph + from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage from ..namespace import NameSpace, is_namespace @@ -558,6 +560,11 @@ class TiDBGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 71784a8b..0ba34ef7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -47,6 +47,7 @@ from .utils import ( set_logger, encode_string_by_tiktoken, ) +from .types import KnowledgeGraph # TODO: TO REMOVE @Yannick config = configparser.ConfigParser() @@ -457,6 +458,13 @@ class LightRAG: self._storages_status = StoragesStatus.FINALIZED logger.debug("Finalized Storages") + async def get_knowledge_graph( + self, nodel_label: str, max_depth: int + ) -> KnowledgeGraph: + return await self.chunk_entity_relation_graph.get_knowledge_graph( + node_label=nodel_label, max_depth=max_depth + ) + def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: import_path = STORAGES[storage_name] storage_class = lazy_external_import(import_path, storage_name)