diff --git a/lightrag/base.py b/lightrag/base.py index 8efbe8a2..3cc7646d 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -92,22 +92,20 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + """Query the vector storage and retrieve top_k results.""" raise NotImplementedError async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - """Use 'content' field from value for embedding, use key as id. - If embedding_func is None, use 'embedding' field from value - """ + """Insert or update vectors in the storage.""" raise NotImplementedError async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" + """Delete a single entity by its name.""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" + """Delete relations for a given entity.""" raise NotImplementedError @@ -116,9 +114,11 @@ class BaseKVStorage(StorageNameSpace): embedding_func: EmbeddingFunc | None = None async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get value by id""" raise NotImplementedError async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get values by ids""" raise NotImplementedError async def filter_keys(self, keys: set[str]) -> set[str]: @@ -126,9 +126,11 @@ class BaseKVStorage(StorageNameSpace): raise NotImplementedError async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Upsert data""" raise NotImplementedError async def drop(self) -> None: + """Drop the storage""" raise NotImplementedError @@ -138,74 +140,62 @@ class BaseGraphStorage(StorageNameSpace): """Check if a node exists in the graph.""" async def has_node(self, node_id: str) -> bool: + """Check if an edge exists in the graph.""" raise NotImplementedError - """Check if an edge exists in the graph.""" - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """Get the degree of a node.""" raise NotImplementedError - """Get the degree of a node.""" - async def node_degree(self, node_id: str) -> int: + """Get the degree of an edge.""" raise NotImplementedError - """Get the degree of an edge.""" - async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get a node by its id.""" raise NotImplementedError - """Get a node by its id.""" - async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get an edge by its source and target node ids.""" raise NotImplementedError - """Get an edge by its source and target node ids.""" - async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get all edges connected to a node.""" raise NotImplementedError - """Get all edges connected to a node.""" async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """Upsert a node into the graph.""" raise NotImplementedError - """Upsert a node into the graph.""" - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """Upsert an edge into the graph.""" raise NotImplementedError - """Upsert an edge into the graph.""" - async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: + """Delete a node from the graph.""" raise NotImplementedError - """Delete a node from the graph.""" - async def delete_node(self, node_id: str) -> None: + """Embed nodes using an algorithm.""" raise NotImplementedError - """Embed nodes using an algorithm.""" - async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: + """Get all labels in the graph.""" raise NotImplementedError("Node embedding is not used in lightrag.") - """Get all labels in the graph.""" - async def get_all_labels(self) -> list[str]: + """Get a knowledge graph of a node.""" raise NotImplementedError - """Get a knowledge graph of a node.""" - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + """Retrieve a subgraph of the knowledge graph starting from a given node.""" raise NotImplementedError diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index a6857f22..a64e4785 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -5,7 +5,8 @@ import os import sys from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union +import numpy as np import pipmaster as pm if not pm.is_installed("psycopg-pool"): @@ -15,6 +16,7 @@ if not pm.is_installed("asyncpg"): pm.install("asyncpg") +from lightrag.types import KnowledgeGraph import psycopg from psycopg.rows import namedtuple_row from psycopg_pool import AsyncConnectionPool, PoolTimeout @@ -396,7 +398,7 @@ class AGEStorage(BaseGraphStorage): ) return single_result["edge_exists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: entity_name_label = node_id.strip('"') query = """ MATCH (n:`{label}`) RETURN n @@ -454,17 +456,7 @@ class AGEStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given labels - - Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes - - Returns: - list: List of all relationships/edges found - """ + ) -> dict[str, str] | None: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -488,7 +480,7 @@ class AGEStorage(BaseGraphStorage): ) return result - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Retrieves all edges (relationships) for a particular node identified by its label. :return: List of dictionaries containing edge information @@ -526,7 +518,7 @@ class AGEStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((AGEQueryException,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the AGE database. @@ -562,8 +554,8 @@ class AGEStorage(BaseGraphStorage): retry=retry_if_exception_type((AGEQueryException,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. @@ -619,3 +611,15 @@ class AGEStorage(BaseGraphStorage): yield connection finally: await self._driver.putconn(connection) + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index f38fd00a..77c627b6 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -3,7 +3,9 @@ import inspect import json import os from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List + +import numpy as np from gremlin_python.driver import client, serializer from gremlin_python.driver.aiohttp.transport import AiohttpTransport @@ -15,6 +17,7 @@ from tenacity import ( wait_exponential, ) +from lightrag.types import KnowledgeGraph from lightrag.utils import logger from ..base import BaseGraphStorage @@ -190,7 +193,7 @@ class GremlinStorage(BaseGraphStorage): return result[0]["has_edge"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: entity_name = GremlinStorage._fix_name(node_id) query = f"""g .V().has('graph', {self.graph_name}) @@ -252,17 +255,7 @@ class GremlinStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given names - - Args: - source_node_id (str): Name of the source nodes - target_node_id (str): Name of the target nodes - - Returns: - dict|None: Dict of found edge properties, or None if not found - """ + ) -> dict[str, str] | None: entity_name_source = GremlinStorage._fix_name(source_node_id) entity_name_target = GremlinStorage._fix_name(target_node_id) query = f"""g @@ -286,11 +279,7 @@ class GremlinStorage(BaseGraphStorage): ) return edge_properties - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: - """ - Retrieves all edges (relationships) for a particular node identified by its name. - :return: List of tuples containing edge sources and targets - """ + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: node_name = GremlinStorage._fix_name(source_node_id) query = f"""g .E() @@ -316,7 +305,7 @@ class GremlinStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((GremlinServerError,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the Gremlin graph. @@ -357,8 +346,8 @@ class GremlinStorage(BaseGraphStorage): retry=retry_if_exception_type((GremlinServerError,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their names. @@ -397,3 +386,17 @@ class GremlinStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 44820ecf..ce15fe29 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -12,7 +12,7 @@ if not pm.is_installed("pymongo"): if not pm.is_installed("motor"): pm.install("motor") -from typing import Any, List, Tuple, Union +from typing import Any, List, Union from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient from pymongo.operations import SearchIndexModel @@ -448,7 +448,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """ Return the full node document (including "edges"), or None if missing. """ @@ -456,11 +456,7 @@ class MongoGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Return the first edge dict from source_node_id to target_node_id if it exists. - Uses a single-hop $graphLookup as demonstration, though a direct find is simpler. - """ + ) -> dict[str, str] | None: pipeline = [ {"$match": {"_id": source_node_id}}, { @@ -486,9 +482,7 @@ class MongoGraphStorage(BaseGraphStorage): return e return None - async def get_node_edges( - self, source_node_id: str - ) -> Union[List[Tuple[str, str]], None]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Return a list of (source_id, target_id) for direct edges from source_node_id. Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. @@ -522,7 +516,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def upsert_node(self, node_id: str, node_data: dict): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Insert or update a node document. If new, create an empty edges array. """ @@ -532,8 +526,8 @@ class MongoGraphStorage(BaseGraphStorage): await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge from source_node_id -> target_node_id with optional 'relation'. If an edge with the same target exists, we remove it and re-insert with updated data. @@ -559,7 +553,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def delete_node(self, node_id: str): + async def delete_node(self, node_id: str) -> None: """ 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. @@ -576,7 +570,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]: + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: """ Placeholder for demonstration, raises NotImplementedError. """ @@ -606,9 +600,7 @@ class MongoGraphStorage(BaseGraphStorage): labels.append(doc["_id"]) return labels - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 15525375..f27a9645 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,7 +3,8 @@ import inspect import os import re from dataclasses import dataclass -from typing import Any, Union, Tuple, List, Dict +from typing import Any, List, Dict +import numpy as np import pipmaster as pm import configparser @@ -191,7 +192,7 @@ class Neo4JStorage(BaseGraphStorage): ) return single_result["edgeExists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier. Args: @@ -252,17 +253,8 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """Find edge between two nodes identified by their labels. + ) -> dict[str, str] | None: - Args: - source_node_id (str): Label of the source node - target_node_id (str): Label of the target node - - Returns: - dict: Edge properties if found, with at least {"weight": 0.0} - None: If error occurs - """ try: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -321,7 +313,7 @@ class Neo4JStorage(BaseGraphStorage): # Return default edge properties on error return {"weight": 0.0, "source_id": None, "target_id": None} - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: node_label = source_node_id.strip('"') """ @@ -364,7 +356,7 @@ class Neo4JStorage(BaseGraphStorage): ) ), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the Neo4j database. @@ -405,8 +397,8 @@ class Neo4JStorage(BaseGraphStorage): ), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. @@ -444,9 +436,7 @@ class Neo4JStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -603,7 +593,7 @@ class Neo4JStorage(BaseGraphStorage): await traverse(label, 0) return result - async def get_all_labels(self) -> List[str]: + async def get_all_labels(self) -> list[str]: """ Get all existing node labels in the database Returns: @@ -627,3 +617,11 @@ class Neo4JStorage(BaseGraphStorage): async for record in result: labels.append(record["label"]) return labels + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index bb84cf82..254bb0ed 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -51,11 +51,12 @@ Usage: import html import os from dataclasses import dataclass -from typing import Any, Union, cast +from typing import Any, cast import networkx as nx import numpy as np +from lightrag.types import KnowledgeGraph from lightrag.utils import ( logger, ) @@ -142,7 +143,7 @@ class NetworkXStorage(BaseGraphStorage): "node2vec": self._node2vec_embed, } - async def index_done_callback(self): + async def index_done_callback(self) -> None: NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: @@ -151,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: return self._graph.has_edge(source_node_id, target_node_id) - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: return self._graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: @@ -162,35 +163,30 @@ class NetworkXStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> dict[str, str] | None: return self._graph.edges.get((source_node_id, target_node_id)) - async def get_node_edges(self, source_node_id: str): + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id)) return None - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: self._graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def delete_node(self, node_id: str): - """ - Delete a node from the graph based on the specified node_id. - - :param node_id: The node_id to delete - """ + async def delete_node(self, node_id: str) -> None: if self._graph.has_node(node_id): self._graph.remove_node(node_id) logger.info(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -226,3 +222,9 @@ class NetworkXStorage(BaseGraphStorage): for source, target in edges: if self._graph.has_edge(source, target): self._graph.remove_edge(source, target) + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 95d888b3..360a4847 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -13,6 +13,7 @@ if not pm.is_installed("oracledb"): pm.install("oracledb") +from lightrag.types import KnowledgeGraph import oracledb from ..base import ( @@ -378,9 +379,7 @@ class OracleGraphStorage(BaseGraphStorage): #################### insert method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]): - """插入或更新节点""" - # print("go into upsert node method") + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: entity_name = node_id entity_type = node_data["entity_type"] description = node_data["description"] @@ -413,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage): async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: """插入或更新边""" # print("go into upsert edge method") source_name = source_node_id @@ -453,8 +452,7 @@ class OracleGraphStorage(BaseGraphStorage): await self.db.execute(merge_sql, data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: - """为节点生成向量""" + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -471,7 +469,7 @@ class OracleGraphStorage(BaseGraphStorage): nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids - async def index_done_callback(self): + async def index_done_callback(self) -> None: """写入graphhml图文件""" logger.info( "Node and edge data had been saved into oracle db already, so nothing to do here!" @@ -493,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage): return False async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - """根据源和目标节点id检查边是否存在""" SQL = SQL_TEMPLATES["has_edge"] params = { "workspace": self.db.workspace, @@ -510,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage): return False async def node_degree(self, node_id: str) -> int: - """根据节点id获取节点的度""" SQL = SQL_TEMPLATES["node_degree"] params = {"workspace": self.db.workspace, "node_id": node_id} # print(SQL) @@ -528,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage): # print("Edge degree",degree) return degree - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """根据节点id获取节点数据""" SQL = SQL_TEMPLATES["get_node"] params = {"workspace": self.db.workspace, "node_id": node_id} @@ -544,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """根据源和目标节点id获取边""" + ) -> dict[str, str] | None: SQL = SQL_TEMPLATES["get_edge"] params = { "workspace": self.db.workspace, @@ -560,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage): # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) return None - async def get_node_edges(self, source_node_id: str): - """根据节点id获取节点的所有边""" + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: if await self.has_node(source_node_id): SQL = SQL_TEMPLATES["get_node_edges"] params = {"workspace": self.db.workspace, "source_node_id": source_node_id} @@ -597,6 +591,14 @@ class OracleGraphStorage(BaseGraphStorage): if res: return res + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 98f9c495..47336190 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,11 +4,13 @@ import json import os import time from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Union import numpy as np import pipmaster as pm +from lightrag.types import KnowledgeGraph + if not pm.is_installed("asyncpg"): pm.install("asyncpg") @@ -835,7 +837,7 @@ class PGGraphStorage(BaseGraphStorage): ) return single_result["edge_exists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: label = PGGraphStorage._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (n:Entity {node_id: "%s"}) @@ -890,17 +892,7 @@ class PGGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given labels - - Args: - source_node_id (str): Label of the source nodes - target_node_id (str): Label of the target nodes - - Returns: - list: List of all relationships/edges found - """ + ) -> dict[str, str] | None: src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) @@ -924,7 +916,7 @@ class PGGraphStorage(BaseGraphStorage): ) return result - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Retrieves all edges (relationships) for a particular node identified by its label. :return: List of dictionaries containing edge information @@ -972,14 +964,7 @@ class PGGraphStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((PGGraphQueryException,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): - """ - Upsert a node in the AGE database. - - Args: - node_id: The unique identifier for the node (used as label) - node_data: Dictionary of node properties - """ + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: label = PGGraphStorage._encode_graph_label(node_id.strip('"')) properties = node_data @@ -1010,8 +995,8 @@ class PGGraphStorage(BaseGraphStorage): retry=retry_if_exception_type((PGGraphQueryException,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. @@ -1053,6 +1038,19 @@ class PGGraphStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 6f388e7f..44c0d9e7 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"): if not pm.is_installed("sqlalchemy"): pm.install("sqlalchemy") +from lightrag.types import KnowledgeGraph from sqlalchemy import create_engine, text from tqdm import tqdm @@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage): self._max_batch_size = self.global_config["embedding_batch_num"] #################### upsert method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: entity_name = node_id entity_type = node_data["entity_type"] description = node_data["description"] @@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage): async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: source_name = source_node_id target_name = target_node_id weight = edge_data["weight"] @@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage): } await self.db.execute(merge_sql, data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage): degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) return degree - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: sql = SQL_TEMPLATES["get_node"] param = {"name": node_id, "workspace": self.db.workspace} return await self.db.query(sql, param) async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> dict[str, str] | None: sql = SQL_TEMPLATES["get_edge"] param = { "source_name": source_node_id, @@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage): } return await self.db.query(sql, param) - async def get_node_edges( - self, source_node_id: str - ) -> Union[list[tuple[str, str]], None]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: sql = SQL_TEMPLATES["get_node_edges"] param = {"source_name": source_node_id, "workspace": self.db.workspace} res = await self.db.query(sql, param, multirows=True) @@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage): else: return [] + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + raise NotImplementedError N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",