Merge pull request #901 from HKUDS/revert-886-clean-2
Revert "removed get_knowledge_graph"
This commit is contained in:
@@ -1683,6 +1683,10 @@ def create_app(args):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# query all graph
@app.get("/graphs")
async def get_knowledge_graph(label: str):
    """Return the knowledge subgraph reachable from the node labeled *label*.

    Delegates to ``rag.get_knowledge_graph`` with a fixed traversal depth
    of 100 (effectively "the whole connected component" for most graphs).
    """
    # NOTE(review): "nodel_label" mirrors the (typo'd) parameter name declared
    # on LightRAG.get_knowledge_graph; renaming it must be coordinated with
    # that method, so the keyword is kept as-is here.
    return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
|
||||
|
||||
# Add Ollama API routes
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||
app.include_router(ollama_api.router, prefix="/api")
|
||||
|
@@ -13,6 +13,7 @@ from typing import (
|
||||
)
|
||||
import numpy as np
|
||||
from .utils import EmbeddingFunc
|
||||
from .types import KnowledgeGraph
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -197,6 +198,12 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
"""Get all labels in the graph."""
|
||||
|
||||
@abstractmethod
async def get_knowledge_graph(
    self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
    """Retrieve a subgraph of the knowledge graph starting from a given node.

    Args:
        node_label: Label identifying the starting node. Some implementations
            (e.g. the Mongo and Neo4j backends) treat "*" as "return the
            whole graph" — confirm per backend.
        max_depth: Maximum traversal depth from the starting node.

    Returns:
        A KnowledgeGraph holding the nodes and edges of the subgraph.
    """
|
||||
|
||||
|
||||
class DocStatus(str, Enum):
|
||||
"""Document processing status"""
|
||||
|
@@ -8,6 +8,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
from lightrag.types import KnowledgeGraph
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
@@ -615,6 +616,11 @@ class AGEStorage(BaseGraphStorage):
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
    """Subgraph retrieval is not implemented for the AGE backend yet."""
    raise NotImplementedError
|
||||
|
||||
async def index_done_callback(self) -> None:
    """No-op: AGE persists writes on its own, so there is nothing to flush."""
|
||||
|
@@ -16,6 +16,7 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from lightrag.utils import logger
|
||||
|
||||
from ..base import BaseGraphStorage
|
||||
@@ -401,3 +402,8 @@ class GremlinStorage(BaseGraphStorage):
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
    """Subgraph retrieval is not implemented for the Gremlin backend yet."""
    raise NotImplementedError
|
||||
|
@@ -16,6 +16,7 @@ from ..base import (
|
||||
)
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("pymongo"):
|
||||
@@ -598,6 +599,179 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
# -------------------------------------------------------------------------
|
||||
# QUERY
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def get_knowledge_graph(
    self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
    """
    Get complete connected subgraph for specified node (including the starting node itself)

    Args:
        node_label: Label of the nodes to start from ("*" returns the
            entire collection as a graph)
        max_depth: Maximum depth of traversal (default: 5)

    Returns:
        KnowledgeGraph object containing nodes and edges of the subgraph
    """
    # NOTE(review): assumes an async (Motor-style) collection — find() and
    # aggregate() are iterated with ``async for`` below; confirm the driver.
    label = node_label
    result = KnowledgeGraph()
    # Dedup sets: documents may be reached more than once during traversal.
    seen_nodes = set()
    seen_edges = set()

    try:
        if label == "*":
            # Get all nodes and edges: full collection scan, one graph node
            # per document, edges read from each document's "edges" array.
            async for node_doc in self.collection.find({}):
                node_id = str(node_doc["_id"])
                if node_id not in seen_nodes:
                    result.nodes.append(
                        KnowledgeGraphNode(
                            id=node_id,
                            labels=[node_doc.get("_id")],
                            # Every field except the id and the adjacency
                            # list becomes a node property.
                            properties={
                                k: v
                                for k, v in node_doc.items()
                                if k not in ["_id", "edges"]
                            },
                        )
                    )
                    seen_nodes.add(node_id)

                # Process edges
                # NOTE(review): edge ids are "{source}-{target}" — this can
                # collide when node ids themselves contain "-"; confirm ids.
                for edge in node_doc.get("edges", []):
                    edge_id = f"{node_id}-{edge['target']}"
                    if edge_id not in seen_edges:
                        result.edges.append(
                            KnowledgeGraphEdge(
                                id=edge_id,
                                type=edge.get("relation", ""),
                                source=node_id,
                                target=edge["target"],
                                properties={
                                    k: v
                                    for k, v in edge.items()
                                    if k not in ["target", "relation"]
                                },
                            )
                        )
                        seen_edges.add(edge_id)
        else:
            # Verify if starting node exists before running the (costlier)
            # aggregation pipeline.
            start_nodes = self.collection.find({"_id": label})
            start_nodes_exist = await start_nodes.to_list(length=1)
            if not start_nodes_exist:
                logger.warning(f"Starting node with label {label} does not exist!")
                return result

            # Use $graphLookup for traversal: follow edges.target -> _id
            # links up to max_depth hops, collecting reached documents
            # into "connected_nodes".
            pipeline = [
                {
                    "$match": {"_id": label}
                },  # Start with nodes having the specified label
                {
                    "$graphLookup": {
                        "from": self._collection_name,
                        "startWith": "$edges.target",
                        "connectFromField": "edges.target",
                        "connectToField": "_id",
                        "maxDepth": max_depth,
                        "depthField": "depth",
                        "as": "connected_nodes",
                    }
                },
            ]

            async for doc in self.collection.aggregate(pipeline):
                # Add the start node
                node_id = str(doc["_id"])
                if node_id not in seen_nodes:
                    result.nodes.append(
                        KnowledgeGraphNode(
                            id=node_id,
                            labels=[
                                doc.get(
                                    "_id",
                                )
                            ],
                            # Exclude bookkeeping fields injected by the
                            # pipeline as well as the adjacency list.
                            properties={
                                k: v
                                for k, v in doc.items()
                                if k
                                not in [
                                    "_id",
                                    "edges",
                                    "connected_nodes",
                                    "depth",
                                ]
                            },
                        )
                    )
                    seen_nodes.add(node_id)

                # Add edges from start node
                for edge in doc.get("edges", []):
                    edge_id = f"{node_id}-{edge['target']}"
                    if edge_id not in seen_edges:
                        result.edges.append(
                            KnowledgeGraphEdge(
                                id=edge_id,
                                type=edge.get("relation", ""),
                                source=node_id,
                                target=edge["target"],
                                properties={
                                    k: v
                                    for k, v in edge.items()
                                    if k not in ["target", "relation"]
                                },
                            )
                        )
                        seen_edges.add(edge_id)

                # Add connected nodes and their edges
                # NOTE(review): ``node_id`` is deliberately rebound here to
                # the connected node's id — the start node's id is no longer
                # needed in the remainder of the loop body.
                for connected in doc.get("connected_nodes", []):
                    node_id = str(connected["_id"])
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[connected.get("_id")],
                                properties={
                                    k: v
                                    for k, v in connected.items()
                                    if k not in ["_id", "edges", "depth"]
                                },
                            )
                        )
                        seen_nodes.add(node_id)

                    # Add edges from connected nodes
                    for edge in connected.get("edges", []):
                        edge_id = f"{node_id}-{edge['target']}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=edge.get("relation", ""),
                                    source=node_id,
                                    target=edge["target"],
                                    properties={
                                        k: v
                                        for k, v in edge.items()
                                        if k not in ["target", "relation"]
                                    },
                                )
                            )
                            seen_edges.add(edge_id)

        logger.info(
            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
        )

    except PyMongoError as e:
        # Best-effort: on driver errors return whatever was collected so far.
        logger.error(f"MongoDB query failed: {str(e)}")

    return result
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Mongo handles persistence automatically
|
||||
|
@@ -17,6 +17,7 @@ from tenacity import (
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("neo4j"):
|
||||
@@ -468,6 +469,99 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def _node2vec_embed(self):
|
||||
print("Implemented but never called.")
|
||||
|
||||
async def get_knowledge_graph(
    self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
    """
    Get complete connected subgraph for specified node (including the starting node itself)

    Key fixes:
    1. Include the starting node itself
    2. Handle multi-label nodes
    3. Clarify relationship directions
    4. Add depth control
    """
    # Strip quotes that callers sometimes wrap around the label.
    label = node_label.strip('"')
    result = KnowledgeGraph()
    # Dedup by the driver's internal ids — a node/relationship can appear
    # in more than one returned path.
    seen_nodes = set()
    seen_edges = set()

    async with self._driver.session(database=self._DATABASE) as session:
        try:
            main_query = ""
            if label == "*":
                # Whole-graph export: all nodes plus all relationships.
                main_query = """
                MATCH (n)
                WITH collect(DISTINCT n) AS nodes
                MATCH ()-[r]-()
                RETURN nodes, collect(DISTINCT r) AS relationships;
                """
            else:
                # Critical debug step: first verify if starting node exists
                # NOTE(review): ``label`` is interpolated into Cypher via an
                # f-string; a label containing a backtick could inject query
                # text. Sanitize or parameterize upstream — confirm.
                validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
                validate_result = await session.run(validate_query)
                if not await validate_result.single():
                    logger.warning(f"Starting node {label} does not exist!")
                    return result

                # Optimized query (including direction handling and self-loops)
                # Requires the APOC plugin; ClientError from a missing APOC
                # is handled below by falling back to _robust_fallback.
                main_query = f"""
                MATCH (start:`{label}`)
                WITH start
                CALL apoc.path.subgraphAll(start, {{
                    relationshipFilter: '>',
                    minLevel: 0,
                    maxLevel: {max_depth},
                    bfs: true
                }})
                YIELD nodes, relationships
                RETURN nodes, relationships
                """
            result_set = await session.run(main_query)
            record = await result_set.single()

            if record:
                # Handle nodes (compatible with multi-label cases)
                for node in record["nodes"]:
                    # Use node ID + label combination as unique identifier
                    # NOTE(review): ``node.id`` is the legacy numeric id;
                    # newer Neo4j drivers prefer ``element_id`` — confirm
                    # the driver version in use.
                    node_id = node.id
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=f"{node_id}",
                                labels=list(node.labels),
                                properties=dict(node),
                            )
                        )
                        seen_nodes.add(node_id)

                # Handle relationships (including direction information)
                for rel in record["relationships"]:
                    edge_id = rel.id
                    if edge_id not in seen_edges:
                        start = rel.start_node
                        end = rel.end_node
                        result.edges.append(
                            KnowledgeGraphEdge(
                                id=f"{edge_id}",
                                type=rel.type,
                                source=f"{start.id}",
                                target=f"{end.id}",
                                properties=dict(rel),
                            )
                        )
                        seen_edges.add(edge_id)

                logger.info(
                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                )

        except neo4jExceptions.ClientError as e:
            # APOC unavailable (or other client error): fall back to a
            # plain-Cypher traversal.
            logger.error(f"APOC query failed: {str(e)}")
            return await self._robust_fallback(label, max_depth)

    return result
|
||||
|
||||
async def _robust_fallback(
|
||||
self, label: str, max_depth: int
|
||||
) -> Dict[str, List[Dict]]:
|
||||
|
@@ -5,6 +5,7 @@ from typing import Any, final
|
||||
import numpy as np
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
)
|
||||
@@ -166,3 +167,8 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
for source, target in edges:
|
||||
if self._graph.has_edge(source, target):
|
||||
self._graph.remove_edge(source, target)
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
    """Subgraph retrieval is not implemented for the NetworkX backend.

    NOTE(review): NetworkX is commonly the default graph storage; callers of
    the /graphs route will get a 500 with this backend — confirm intended.
    """
    raise NotImplementedError
|
||||
|
@@ -8,6 +8,7 @@ from typing import Any, Union, final
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
|
||||
from ..base import (
|
||||
BaseGraphStorage,
|
||||
@@ -669,6 +670,11 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
async def delete_node(self, node_id: str) -> None:
    """Node deletion is not implemented for the Oracle backend."""
    raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
    """Subgraph retrieval is not implemented for the Oracle backend yet."""
    raise NotImplementedError
|
||||
|
||||
|
||||
N_T = {
|
||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||
|
@@ -7,6 +7,7 @@ from typing import Any, Union, final
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
|
||||
import sys
|
||||
from tenacity import (
|
||||
@@ -1084,6 +1085,11 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
    """Subgraph retrieval is not implemented for the PostgreSQL/AGE backend yet."""
    raise NotImplementedError
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
|
||||
|
@@ -5,6 +5,8 @@ from typing import Any, Union, final
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
|
||||
|
||||
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
@@ -558,6 +560,11 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
async def delete_node(self, node_id: str) -> None:
    """Node deletion is not implemented for the TiDB backend."""
    raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
    """Subgraph retrieval is not implemented for the TiDB backend yet."""
    raise NotImplementedError
|
||||
|
||||
|
||||
N_T = {
|
||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||
|
@@ -47,6 +47,7 @@ from .utils import (
|
||||
set_logger,
|
||||
encode_string_by_tiktoken,
|
||||
)
|
||||
from .types import KnowledgeGraph
|
||||
|
||||
# TODO: TO REMOVE @Yannick
|
||||
config = configparser.ConfigParser()
|
||||
@@ -457,6 +458,13 @@ class LightRAG:
|
||||
self._storages_status = StoragesStatus.FINALIZED
|
||||
logger.debug("Finalized Storages")
|
||||
|
||||
async def get_knowledge_graph(
    self, nodel_label: str, max_depth: int
) -> KnowledgeGraph:
    """Delegate subgraph retrieval to the configured graph storage backend.

    Args:
        nodel_label: Label of the starting node. NOTE(review): the name is a
            typo for "node_label", but it is part of the public keyword
            interface (the /graphs route passes ``nodel_label=...``), so it
            cannot be renamed here without coordinating all callers.
        max_depth: Maximum traversal depth. Unlike the storage-level
            signatures (which default to 5), this parameter has no default.

    Returns:
        The KnowledgeGraph produced by ``chunk_entity_relation_graph``.
    """
    return await self.chunk_entity_relation_graph.get_knowledge_graph(
        node_label=nodel_label, max_depth=max_depth
    )
|
||||
|
||||
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
||||
import_path = STORAGES[storage_name]
|
||||
storage_class = lazy_external_import(import_path, storage_name)
|
||||
|
Reference in New Issue
Block a user