Updated: cleaned up what was implemented on DocStatusStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:53:59 +01:00
parent 71a18d1de9
commit 882190a515
9 changed files with 164 additions and 168 deletions

View File

@@ -12,7 +12,7 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"):
pm.install("motor")
from typing import Any, List, Tuple, Union
from typing import Any, List, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
@@ -448,7 +448,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""
Return the full node document (including "edges"), or None if missing.
"""
@@ -456,11 +456,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Return the first edge dict from source_node_id to target_node_id if it exists.
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
"""
) -> dict[str, str] | None:
pipeline = [
{"$match": {"_id": source_node_id}},
{
@@ -486,9 +482,7 @@ class MongoGraphStorage(BaseGraphStorage):
return e
return None
async def get_node_edges(
self, source_node_id: str
) -> Union[List[Tuple[str, str]], None]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Return a list of (source_id, target_id) for direct edges from source_node_id.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
@@ -522,7 +516,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def upsert_node(self, node_id: str, node_data: dict):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Insert or update a node document. If new, create an empty edges array.
"""
@@ -532,8 +526,8 @@ class MongoGraphStorage(BaseGraphStorage):
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict
):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
If an edge with the same target exists, we remove it and re-insert with updated data.
@@ -559,7 +553,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def delete_node(self, node_id: str):
async def delete_node(self, node_id: str) -> None:
"""
1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
@@ -576,7 +570,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Placeholder for demonstration, raises NotImplementedError.
"""
@@ -606,9 +600,7 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"])
return labels
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)