updated clean of what implemented on DocStatusStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:53:59 +01:00
parent 71a18d1de9
commit 882190a515
9 changed files with 164 additions and 168 deletions

View File

@@ -51,11 +51,12 @@ Usage:
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
from typing import Any, cast
import networkx as nx
import numpy as np
from lightrag.types import KnowledgeGraph
from lightrag.utils import (
logger,
)
@@ -142,7 +143,7 @@ class NetworkXStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
async def index_done_callback(self) -> None:
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
@@ -151,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
@@ -162,35 +163,30 @@ class NetworkXStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str):
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
"""
async def delete_node(self, node_id: str) -> None:
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -226,3 +222,9 @@ class NetworkXStorage(BaseGraphStorage):
for source, target in edges:
if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError