updated clean of what implemented on DocStatusStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:53:59 +01:00
parent 71a18d1de9
commit 882190a515
9 changed files with 164 additions and 168 deletions

View File

@@ -92,22 +92,20 @@ class StorageNameSpace:
class BaseVectorStorage(StorageNameSpace): class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc embedding_func: EmbeddingFunc
meta_fields: set[str] = field(default_factory=set) meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results."""
raise NotImplementedError raise NotImplementedError
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Use 'content' field from value for embedding, use key as id. """Insert or update vectors in the storage."""
If embedding_func is None, use 'embedding' field from value
"""
raise NotImplementedError raise NotImplementedError
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name""" """Delete a single entity by its name."""
raise NotImplementedError raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity."""
raise NotImplementedError raise NotImplementedError
@@ -116,9 +114,11 @@ class BaseKVStorage(StorageNameSpace):
embedding_func: EmbeddingFunc | None = None embedding_func: EmbeddingFunc | None = None
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get value by id"""
raise NotImplementedError raise NotImplementedError
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get values by ids"""
raise NotImplementedError raise NotImplementedError
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
@@ -126,9 +126,11 @@ class BaseKVStorage(StorageNameSpace):
raise NotImplementedError raise NotImplementedError
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Upsert data"""
raise NotImplementedError raise NotImplementedError
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the storage"""
raise NotImplementedError raise NotImplementedError
@@ -138,74 +140,62 @@ class BaseGraphStorage(StorageNameSpace):
"""Check if a node exists in the graph.""" """Check if a node exists in the graph."""
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
"""Check if an edge exists in the graph.""" """Check if an edge exists in the graph."""
raise NotImplementedError
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError
"""Get the degree of a node.""" """Get the degree of a node."""
raise NotImplementedError
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
raise NotImplementedError
"""Get the degree of an edge.""" """Get the degree of an edge."""
raise NotImplementedError
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError
"""Get a node by its id.""" """Get a node by its id."""
raise NotImplementedError
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
raise NotImplementedError
"""Get an edge by its source and target node ids.""" """Get an edge by its source and target node ids."""
raise NotImplementedError
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
"""Get all edges connected to a node."""
raise NotImplementedError raise NotImplementedError
"""Get all edges connected to a node."""
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
raise NotImplementedError
"""Upsert a node into the graph.""" """Upsert a node into the graph."""
raise NotImplementedError
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
raise NotImplementedError
"""Upsert an edge into the graph.""" """Upsert an edge into the graph."""
raise NotImplementedError
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
raise NotImplementedError
"""Delete a node from the graph.""" """Delete a node from the graph."""
raise NotImplementedError
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
"""Embed nodes using an algorithm.""" """Embed nodes using an algorithm."""
raise NotImplementedError
async def embed_nodes( async def embed_nodes(
self, algorithm: str self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
"""Get all labels in the graph."""
raise NotImplementedError("Node embedding is not used in lightrag.") raise NotImplementedError("Node embedding is not used in lightrag.")
"""Get all labels in the graph."""
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
"""Get a knowledge graph of a node."""
raise NotImplementedError raise NotImplementedError
"""Get a knowledge graph of a node.""" async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError raise NotImplementedError

View File

@@ -5,7 +5,8 @@ import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union from typing import Any, Dict, List, NamedTuple, Optional, Union
import numpy as np
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("psycopg-pool"): if not pm.is_installed("psycopg-pool"):
@@ -15,6 +16,7 @@ if not pm.is_installed("asyncpg"):
pm.install("asyncpg") pm.install("asyncpg")
from lightrag.types import KnowledgeGraph
import psycopg import psycopg
from psycopg.rows import namedtuple_row from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout from psycopg_pool import AsyncConnectionPool, PoolTimeout
@@ -396,7 +398,7 @@ class AGEStorage(BaseGraphStorage):
) )
return single_result["edge_exists"] return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
entity_name_label = node_id.strip('"') entity_name_label = node_id.strip('"')
query = """ query = """
MATCH (n:`{label}`) RETURN n MATCH (n:`{label}`) RETURN n
@@ -454,17 +456,7 @@ class AGEStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""
Find all edges between nodes of two given labels
Args:
source_node_label (str): Label of the source nodes
target_node_label (str): Label of the target nodes
Returns:
list: List of all relationships/edges found
"""
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
@@ -488,7 +480,7 @@ class AGEStorage(BaseGraphStorage):
) )
return result return result
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
""" """
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information :return: List of dictionaries containing edge information
@@ -526,7 +518,7 @@ class AGEStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((AGEQueryException,)), retry=retry_if_exception_type((AGEQueryException,)),
) )
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Upsert a node in the AGE database. Upsert a node in the AGE database.
@@ -562,8 +554,8 @@ class AGEStorage(BaseGraphStorage):
retry=retry_if_exception_type((AGEQueryException,)), retry=retry_if_exception_type((AGEQueryException,)),
) )
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
@@ -619,3 +611,15 @@ class AGEStorage(BaseGraphStorage):
yield connection yield connection
finally: finally:
await self._driver.putconn(connection) await self._driver.putconn(connection)
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -3,7 +3,9 @@ import inspect
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List
import numpy as np
from gremlin_python.driver import client, serializer from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport from gremlin_python.driver.aiohttp.transport import AiohttpTransport
@@ -15,6 +17,7 @@ from tenacity import (
wait_exponential, wait_exponential,
) )
from lightrag.types import KnowledgeGraph
from lightrag.utils import logger from lightrag.utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
@@ -190,7 +193,7 @@ class GremlinStorage(BaseGraphStorage):
return result[0]["has_edge"] return result[0]["has_edge"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
entity_name = GremlinStorage._fix_name(node_id) entity_name = GremlinStorage._fix_name(node_id)
query = f"""g query = f"""g
.V().has('graph', {self.graph_name}) .V().has('graph', {self.graph_name})
@@ -252,17 +255,7 @@ class GremlinStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""
Find all edges between nodes of two given names
Args:
source_node_id (str): Name of the source nodes
target_node_id (str): Name of the target nodes
Returns:
dict|None: Dict of found edge properties, or None if not found
"""
entity_name_source = GremlinStorage._fix_name(source_node_id) entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id) entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f"""g query = f"""g
@@ -286,11 +279,7 @@ class GremlinStorage(BaseGraphStorage):
) )
return edge_properties return edge_properties
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Retrieves all edges (relationships) for a particular node identified by its name.
:return: List of tuples containing edge sources and targets
"""
node_name = GremlinStorage._fix_name(source_node_id) node_name = GremlinStorage._fix_name(source_node_id)
query = f"""g query = f"""g
.E() .E()
@@ -316,7 +305,7 @@ class GremlinStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)), retry=retry_if_exception_type((GremlinServerError,)),
) )
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Upsert a node in the Gremlin graph. Upsert a node in the Gremlin graph.
@@ -357,8 +346,8 @@ class GremlinStorage(BaseGraphStorage):
retry=retry_if_exception_type((GremlinServerError,)), retry=retry_if_exception_type((GremlinServerError,)),
) )
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their names. Upsert an edge and its properties between two nodes identified by their names.
@@ -397,3 +386,17 @@ class GremlinStorage(BaseGraphStorage):
async def _node2vec_embed(self): async def _node2vec_embed(self):
print("Implemented but never called.") print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -12,7 +12,7 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"): if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
from typing import Any, List, Tuple, Union from typing import Any, List, Union
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.operations import SearchIndexModel from pymongo.operations import SearchIndexModel
@@ -448,7 +448,7 @@ class MongoGraphStorage(BaseGraphStorage):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# #
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
""" """
Return the full node document (including "edges"), or None if missing. Return the full node document (including "edges"), or None if missing.
""" """
@@ -456,11 +456,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""
Return the first edge dict from source_node_id to target_node_id if it exists.
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
"""
pipeline = [ pipeline = [
{"$match": {"_id": source_node_id}}, {"$match": {"_id": source_node_id}},
{ {
@@ -486,9 +482,7 @@ class MongoGraphStorage(BaseGraphStorage):
return e return e
return None return None
async def get_node_edges( async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
self, source_node_id: str
) -> Union[List[Tuple[str, str]], None]:
""" """
Return a list of (source_id, target_id) for direct edges from source_node_id. Return a list of (source_id, target_id) for direct edges from source_node_id.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
@@ -522,7 +516,7 @@ class MongoGraphStorage(BaseGraphStorage):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# #
async def upsert_node(self, node_id: str, node_data: dict): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Insert or update a node document. If new, create an empty edges array. Insert or update a node document. If new, create an empty edges array.
""" """
@@ -532,8 +526,8 @@ class MongoGraphStorage(BaseGraphStorage):
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
""" """
Upsert an edge from source_node_id -> target_node_id with optional 'relation'. Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
If an edge with the same target exists, we remove it and re-insert with updated data. If an edge with the same target exists, we remove it and re-insert with updated data.
@@ -559,7 +553,7 @@ class MongoGraphStorage(BaseGraphStorage):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# #
async def delete_node(self, node_id: str): async def delete_node(self, node_id: str) -> None:
""" """
1) Remove node's doc entirely. 1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id. 2) Remove inbound edges from any doc that references node_id.
@@ -576,7 +570,7 @@ class MongoGraphStorage(BaseGraphStorage):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# #
async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]: async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
""" """
Placeholder for demonstration, raises NotImplementedError. Placeholder for demonstration, raises NotImplementedError.
""" """
@@ -606,9 +600,7 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"]) labels.append(doc["_id"])
return labels return labels
async def get_knowledge_graph( async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)

View File

@@ -3,7 +3,8 @@ import inspect
import os import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict from typing import Any, List, Dict
import numpy as np
import pipmaster as pm import pipmaster as pm
import configparser import configparser
@@ -191,7 +192,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
return single_result["edgeExists"] return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier. """Get node by its label identifier.
Args: Args:
@@ -252,17 +253,8 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""Find edge between two nodes identified by their labels.
Args:
source_node_id (str): Label of the source node
target_node_id (str): Label of the target node
Returns:
dict: Edge properties if found, with at least {"weight": 0.0}
None: If error occurs
"""
try: try:
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
@@ -321,7 +313,7 @@ class Neo4JStorage(BaseGraphStorage):
# Return default edge properties on error # Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None} return {"weight": 0.0, "source_id": None, "target_id": None}
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
node_label = source_node_id.strip('"') node_label = source_node_id.strip('"')
""" """
@@ -364,7 +356,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
), ),
) )
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Upsert a node in the Neo4j database. Upsert a node in the Neo4j database.
@@ -405,8 +397,8 @@ class Neo4JStorage(BaseGraphStorage):
), ),
) )
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
@@ -444,9 +436,7 @@ class Neo4JStorage(BaseGraphStorage):
async def _node2vec_embed(self): async def _node2vec_embed(self):
print("Implemented but never called.") print("Implemented but never called.")
async def get_knowledge_graph( async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
@@ -603,7 +593,7 @@ class Neo4JStorage(BaseGraphStorage):
await traverse(label, 0) await traverse(label, 0)
return result return result
async def get_all_labels(self) -> List[str]: async def get_all_labels(self) -> list[str]:
""" """
Get all existing node labels in the database Get all existing node labels in the database
Returns: Returns:
@@ -627,3 +617,11 @@ class Neo4JStorage(BaseGraphStorage):
async for record in result: async for record in result:
labels.append(record["label"]) labels.append(record["label"])
return labels return labels
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError

View File

@@ -51,11 +51,12 @@ Usage:
import html import html
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, cast from typing import Any, cast
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from lightrag.types import KnowledgeGraph
from lightrag.utils import ( from lightrag.utils import (
logger, logger,
) )
@@ -142,7 +143,7 @@ class NetworkXStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def index_done_callback(self): async def index_done_callback(self) -> None:
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
@@ -151,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id) return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id) return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
@@ -162,35 +163,30 @@ class NetworkXStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id)) return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id): if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id)) return list(self._graph.edges(source_node_id))
return None return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data) self._graph.add_node(node_id, **node_data)
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data) self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str): async def delete_node(self, node_id: str) -> None:
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
"""
if self._graph.has_node(node_id): if self._graph.has_node(node_id):
self._graph.remove_node(node_id) self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.") logger.info(f"Node {node_id} deleted from the graph.")
else: else:
logger.warning(f"Node {node_id} not found in the graph for deletion.") logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -226,3 +222,9 @@ class NetworkXStorage(BaseGraphStorage):
for source, target in edges: for source, target in edges:
if self._graph.has_edge(source, target): if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target) self._graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -13,6 +13,7 @@ if not pm.is_installed("oracledb"):
pm.install("oracledb") pm.install("oracledb")
from lightrag.types import KnowledgeGraph
import oracledb import oracledb
from ..base import ( from ..base import (
@@ -378,9 +379,7 @@ class OracleGraphStorage(BaseGraphStorage):
#################### insert method ################ #################### insert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""插入或更新节点"""
# print("go into upsert node method")
entity_name = node_id entity_name = node_id
entity_type = node_data["entity_type"] entity_type = node_data["entity_type"]
description = node_data["description"] description = node_data["description"]
@@ -413,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
"""插入或更新边""" """插入或更新边"""
# print("go into upsert edge method") # print("go into upsert edge method")
source_name = source_node_id source_name = source_node_id
@@ -453,8 +452,7 @@ class OracleGraphStorage(BaseGraphStorage):
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
"""为节点生成向量"""
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -471,7 +469,7 @@ class OracleGraphStorage(BaseGraphStorage):
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids return embeddings, nodes_ids
async def index_done_callback(self): async def index_done_callback(self) -> None:
"""写入graphhml图文件""" """写入graphhml图文件"""
logger.info( logger.info(
"Node and edge data had been saved into oracle db already, so nothing to do here!" "Node and edge data had been saved into oracle db already, so nothing to do here!"
@@ -493,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False return False
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""根据源和目标节点id检查边是否存在"""
SQL = SQL_TEMPLATES["has_edge"] SQL = SQL_TEMPLATES["has_edge"]
params = { params = {
"workspace": self.db.workspace, "workspace": self.db.workspace,
@@ -510,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False return False
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
SQL = SQL_TEMPLATES["node_degree"] SQL = SQL_TEMPLATES["node_degree"]
params = {"workspace": self.db.workspace, "node_id": node_id} params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL) # print(SQL)
@@ -528,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge degree",degree) # print("Edge degree",degree)
return degree return degree
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""根据节点id获取节点数据""" """根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"] SQL = SQL_TEMPLATES["get_node"]
params = {"workspace": self.db.workspace, "node_id": node_id} params = {"workspace": self.db.workspace, "node_id": node_id}
@@ -544,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""根据源和目标节点id获取边"""
SQL = SQL_TEMPLATES["get_edge"] SQL = SQL_TEMPLATES["get_edge"]
params = { params = {
"workspace": self.db.workspace, "workspace": self.db.workspace,
@@ -560,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
return None return None
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""根据节点id获取节点的所有边"""
if await self.has_node(source_node_id): if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"] SQL = SQL_TEMPLATES["get_node_edges"]
params = {"workspace": self.db.workspace, "source_node_id": source_node_id} params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
@@ -597,6 +591,14 @@ class OracleGraphStorage(BaseGraphStorage):
if res: if res:
return res return res
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError
N_T = { N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -4,11 +4,13 @@ import json
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Union
import numpy as np import numpy as np
import pipmaster as pm import pipmaster as pm
from lightrag.types import KnowledgeGraph
if not pm.is_installed("asyncpg"): if not pm.is_installed("asyncpg"):
pm.install("asyncpg") pm.install("asyncpg")
@@ -835,7 +837,7 @@ class PGGraphStorage(BaseGraphStorage):
) )
return single_result["edge_exists"] return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"')) label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:Entity {node_id: "%s"})
@@ -890,17 +892,7 @@ class PGGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
"""
Find all edges between nodes of two given labels
Args:
source_node_id (str): Label of the source nodes
target_node_id (str): Label of the target nodes
Returns:
list: List of all relationships/edges found
"""
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
@@ -924,7 +916,7 @@ class PGGraphStorage(BaseGraphStorage):
) )
return result return result
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
""" """
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information :return: List of dictionaries containing edge information
@@ -972,14 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((PGGraphQueryException,)), retry=retry_if_exception_type((PGGraphQueryException,)),
) )
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the AGE database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = PGGraphStorage._encode_graph_label(node_id.strip('"')) label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
properties = node_data properties = node_data
@@ -1010,8 +995,8 @@ class PGGraphStorage(BaseGraphStorage):
retry=retry_if_exception_type((PGGraphQueryException,)), retry=retry_if_exception_type((PGGraphQueryException,)),
) )
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.
@@ -1053,6 +1038,19 @@ class PGGraphStorage(BaseGraphStorage):
async def _node2vec_embed(self): async def _node2vec_embed(self):
print("Implemented but never called.") print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError
NAMESPACE_TABLE_MAP = { NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"): if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy") pm.install("sqlalchemy")
from lightrag.types import KnowledgeGraph
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, text
from tqdm import tqdm from tqdm import tqdm
@@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
#################### upsert method ################ #################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id entity_name = node_id
entity_type = node_data["entity_type"] entity_type = node_data["entity_type"]
description = node_data["description"] description = node_data["description"]
@@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage):
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
): ) -> None:
source_name = source_node_id source_name = source_node_id
target_name = target_node_id target_name = target_node_id
weight = edge_data["weight"] weight = edge_data["weight"]
@@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage):
} }
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage):
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree return degree
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_node"] sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace} param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param) return await self.db.query(sql, param)
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_edge"] sql = SQL_TEMPLATES["get_edge"]
param = { param = {
"source_name": source_node_id, "source_name": source_node_id,
@@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage):
} }
return await self.db.query(sql, param) return await self.db.query(sql, param)
async def get_node_edges( async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
sql = SQL_TEMPLATES["get_node_edges"] sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace} param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True) res = await self.db.query(sql, param, multirows=True)
@@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage):
else: else:
return [] return []
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError
N_T = { N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",