Merge pull request #901 from HKUDS/revert-886-clean-2

Revert "removed get_knowledge_graph"
This commit is contained in:
Yannick Stephan
2025-02-20 14:31:00 +01:00
committed by GitHub
11 changed files with 324 additions and 0 deletions

View File

@@ -1683,6 +1683,10 @@ def create_app(args):
raise HTTPException(status_code=500, detail=str(e))
# query all graph
@app.get("/graphs")
async def get_knowledge_graph(label: str):
return await rag.get_knowledge_graph(node_label=label, max_depth=100)
# Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
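
For quick verification, the restored endpoint can be exercised with a short client script. A minimal sketch, assuming the server built by create_app is running locally (the host, port, and example label are assumptions; note the route hardcodes max_depth=100):

import httpx

# Hypothetical client call; host/port and the label value are assumptions.
resp = httpx.get("http://localhost:9621/graphs", params={"label": "Apple"})
resp.raise_for_status()
graph = resp.json()  # serialized KnowledgeGraph: {"nodes": [...], "edges": [...]}
print(len(graph["nodes"]), "nodes,", len(graph["edges"]), "edges")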

View File

@@ -13,6 +13,7 @@ from typing import (
)
import numpy as np
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
load_dotenv()
@@ -197,6 +198,12 @@ class BaseGraphStorage(StorageNameSpace, ABC):
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""Get all labels in the graph."""
@abstractmethod
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
class DocStatus(str, Enum):
"""Document processing status"""

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
import numpy as np
import pipmaster as pm
from lightrag.types import KnowledgeGraph
from tenacity import (
retry,
@@ -615,6 +616,11 @@ class AGEStorage(BaseGraphStorage):
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
async def index_done_callback(self) -> None:
# AGE handles persistence automatically
pass

View File

@@ -16,6 +16,7 @@ from tenacity import (
wait_exponential,
)
from lightrag.types import KnowledgeGraph
from lightrag.utils import logger
from ..base import BaseGraphStorage
@@ -401,3 +402,8 @@ class GremlinStorage(BaseGraphStorage):
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -16,6 +16,7 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm
if not pm.is_installed("pymongo"):
@@ -598,6 +599,179 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
# QUERY
# -------------------------------------------------------------------------
#
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get the complete connected subgraph for the specified node (including the starting node itself)
Args:
node_label: Label of the nodes to start from
max_depth: Maximum depth of traversal (default: 5)
Returns:
KnowledgeGraph object containing nodes and edges of the subgraph
"""
label = node_label
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
try:
if label == "*":
# Get all nodes and edges
async for node_doc in self.collection.find({}):
node_id = str(node_doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_doc.get("_id")],
properties={
k: v
for k, v in node_doc.items()
if k not in ["_id", "edges"]
},
)
)
seen_nodes.add(node_id)
# Process edges
for edge in node_doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
else:
# Verify if starting node exists
start_nodes = self.collection.find({"_id": label})
start_nodes_exist = await start_nodes.to_list(length=1)
if not start_nodes_exist:
logger.warning(f"Starting node with label {label} does not exist!")
return result
# Use $graphLookup for traversal
pipeline = [
{
"$match": {"_id": label}
}, # Start with nodes having the specified label
{
"$graphLookup": {
"from": self._collection_name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_nodes",
}
},
]
async for doc in self.collection.aggregate(pipeline):
# Add the start node
node_id = str(doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[
doc.get(
"_id",
)
],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"edges",
"connected_nodes",
"depth",
]
},
)
)
seen_nodes.add(node_id)
# Add edges from start node
for edge in doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
# Add connected nodes and their edges
for connected in doc.get("connected_nodes", []):
node_id = str(connected["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[connected.get("_id")],
properties={
k: v
for k, v in connected.items()
if k not in ["_id", "edges", "depth"]
},
)
)
seen_nodes.add(node_id)
# Add edges from connected nodes
for edge in connected.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except PyMongoError as e:
logger.error(f"MongoDB query failed: {str(e)}")
return result
async def index_done_callback(self) -> None:
# Mongo handles persistence automatically
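
The $graphLookup traversal above presupposes an adjacency-list document layout: one document per node, keyed by _id, with outgoing edges embedded under "edges". A hypothetical document consistent with the fields the code reads (all values are illustrative only):

# One node per document; _id doubles as the node's id and label.
example_node_doc = {
    "_id": "Apple",
    "description": "a company",  # extra keys land in KnowledgeGraphNode.properties
    "edges": [
        # "target" must match another document's _id; "relation" becomes the
        # edge type; remaining keys land in KnowledgeGraphEdge.properties.
        {"target": "Steve Jobs", "relation": "founded_by", "weight": 1.0},
    ],
}
# $graphLookup then follows edges.target -> _id links up to maxDepth hops and
# returns the matched documents in "connected_nodes", each tagged with "depth".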

View File

@@ -17,6 +17,7 @@ from tenacity import (
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm
if not pm.is_installed("neo4j"):
@@ -468,6 +469,99 @@ class Neo4JStorage(BaseGraphStorage):
async def _node2vec_embed(self):
print("Implemented but never called.")
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get the complete connected subgraph for the specified node (including the starting node itself)
Key fixes:
1. Include the starting node itself
2. Handle multi-label nodes
3. Clarify relationship directions
4. Add depth control
"""
label = node_label.strip('"')
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
async with self._driver.session(database=self._DATABASE) as session:
try:
main_query = ""
if label == "*":
main_query = """
MATCH (n)
WITH collect(DISTINCT n) AS nodes
MATCH ()-[r]-()
RETURN nodes, collect(DISTINCT r) AS relationships;
"""
else:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
return result
# Optimized query (including direction handling and self-loops)
main_query = f"""
MATCH (start:`{label}`)
WITH start
CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>',
minLevel: 0,
maxLevel: {max_depth},
bfs: true
}})
YIELD nodes, relationships
RETURN nodes, relationships
"""
result_set = await session.run(main_query)
record = await result_set.single()
if record:
# Handle nodes (compatible with multi-label cases)
for node in record["nodes"]:
# Use node ID + label combination as unique identifier
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=list(node.labels),
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except neo4jExceptions.ClientError as e:
logger.error(f"APOC query failed: {str(e)}")
return await self._robust_fallback(label, max_depth)
return result
async def _robust_fallback(
self, label: str, max_depth: int
) -> Dict[str, List[Dict]]:
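
The happy path above depends on the APOC plugin (apoc.path.subgraphAll); without it, the ClientError handler falls through to _robust_fallback. A hypothetical caller-side sketch, assuming a configured Neo4JStorage instance and an existing label (note that relationshipFilter '>' restricts the walk to outgoing relationships):

# Hypothetical usage; the storage instance and the label are assumptions.
async def dump_subgraph(storage: "Neo4JStorage") -> None:
    kg = await storage.get_knowledge_graph(node_label="Apple", max_depth=3)
    for node in kg.nodes:
        print(node.id, node.labels)
    for edge in kg.edges:
        print(f"{edge.source} -[{edge.type}]-> {edge.target}")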

View File

@@ -5,6 +5,7 @@ from typing import Any, final
import numpy as np
from lightrag.types import KnowledgeGraph
from lightrag.utils import (
logger,
)
@@ -166,3 +167,8 @@ class NetworkXStorage(BaseGraphStorage):
for source, target in edges:
if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target)
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -8,6 +8,7 @@ from typing import Any, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
from ..base import (
BaseGraphStorage,
@@ -669,6 +670,11 @@ class OracleGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -7,6 +7,7 @@ from typing import Any, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
import sys
from tenacity import (
@@ -1084,6 +1085,11 @@ class PGGraphStorage(BaseGraphStorage):
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
async def drop(self) -> None:
"""Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"]

View File

@@ -5,6 +5,8 @@ from typing import Any, Union, final
import numpy as np
from lightrag.types import KnowledgeGraph
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
@@ -558,6 +560,11 @@ class TiDBGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -47,6 +47,7 @@ from .utils import (
set_logger,
encode_string_by_tiktoken,
)
from .types import KnowledgeGraph
# TODO: TO REMOVE @Yannick
config = configparser.ConfigParser()
@@ -457,6 +458,13 @@ class LightRAG:
self._storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages")
async def get_knowledge_graph(
self, node_label: str, max_depth: int
) -> KnowledgeGraph:
return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label=node_label, max_depth=max_depth
)
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
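
Taken together, the revert restores a single call path: HTTP /graphs -> LightRAG.get_knowledge_graph -> the active graph storage. A minimal end-to-end sketch, with the constructor arguments shown here being assumptions, and only backends that actually implement the method (MongoDB or Neo4j in this commit) succeeding; the others still raise NotImplementedError:

import asyncio
from lightrag import LightRAG

async def main() -> None:
    # working_dir and storage selection are deployment-specific assumptions
    rag = LightRAG(working_dir="./rag_storage", graph_storage="Neo4JStorage")
    kg = await rag.get_knowledge_graph(node_label="*", max_depth=5)
    print(f"{len(kg.nodes)} nodes | {len(kg.edges)} edges")

asyncio.run(main())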