Revert "removed get_knowledge_graph"

This commit is contained in:
Yannick Stephan
2025-02-20 14:29:36 +01:00
committed by GitHub
parent 0d4c580859
commit 439685e69c
11 changed files with 324 additions and 0 deletions

View File

@@ -16,6 +16,7 @@ from ..base import (
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm
if not pm.is_installed("pymongo"):
@@ -598,6 +599,179 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
# QUERY
# -------------------------------------------------------------------------
#
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Args:
node_label: Label of the nodes to start from
max_depth: Maximum depth of traversal (default: 5)
Returns:
KnowledgeGraph object containing nodes and edges of the subgraph
"""
label = node_label
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
try:
if label == "*":
# Get all nodes and edges
async for node_doc in self.collection.find({}):
node_id = str(node_doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_doc.get("_id")],
properties={
k: v
for k, v in node_doc.items()
if k not in ["_id", "edges"]
},
)
)
seen_nodes.add(node_id)
# Process edges
for edge in node_doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
else:
# Verify if starting node exists
start_nodes = self.collection.find({"_id": label})
start_nodes_exist = await start_nodes.to_list(length=1)
if not start_nodes_exist:
logger.warning(f"Starting node with label {label} does not exist!")
return result
# Use $graphLookup for traversal
pipeline = [
{
"$match": {"_id": label}
}, # Start with nodes having the specified label
{
"$graphLookup": {
"from": self._collection_name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_nodes",
}
},
]
async for doc in self.collection.aggregate(pipeline):
# Add the start node
node_id = str(doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[
doc.get(
"_id",
)
],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"edges",
"connected_nodes",
"depth",
]
},
)
)
seen_nodes.add(node_id)
# Add edges from start node
for edge in doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
# Add connected nodes and their edges
for connected in doc.get("connected_nodes", []):
node_id = str(connected["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[connected.get("_id")],
properties={
k: v
for k, v in connected.items()
if k not in ["_id", "edges", "depth"]
},
)
)
seen_nodes.add(node_id)
# Add edges from connected nodes
for edge in connected.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except PyMongoError as e:
logger.error(f"MongoDB query failed: {str(e)}")
return result
async def index_done_callback(self) -> None:
# Mongo handles persistence automatically