updated clean of what implemented on DocStatusStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:53:59 +01:00
parent 71a18d1de9
commit 882190a515
9 changed files with 164 additions and 168 deletions

View File

@@ -92,22 +92,20 @@ class StorageNameSpace:
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results."""
raise NotImplementedError
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Use 'content' field from value for embedding, use key as id.
If embedding_func is None, use 'embedding' field from value
"""
"""Insert or update vectors in the storage."""
raise NotImplementedError
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
"""Delete a single entity by its name."""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
"""Delete relations for a given entity."""
raise NotImplementedError
@@ -116,9 +114,11 @@ class BaseKVStorage(StorageNameSpace):
embedding_func: EmbeddingFunc | None = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get value by id"""
raise NotImplementedError
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get values by ids"""
raise NotImplementedError
async def filter_keys(self, keys: set[str]) -> set[str]:
@@ -126,9 +126,11 @@ class BaseKVStorage(StorageNameSpace):
raise NotImplementedError
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Upsert data"""
raise NotImplementedError
async def drop(self) -> None:
"""Drop the storage"""
raise NotImplementedError
@@ -138,74 +140,62 @@ class BaseGraphStorage(StorageNameSpace):
"""Check if a node exists in the graph."""
async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
"""Check if an edge exists in the graph."""
raise NotImplementedError
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError
"""Get the degree of a node."""
raise NotImplementedError
async def node_degree(self, node_id: str) -> int:
raise NotImplementedError
"""Get the degree of an edge."""
raise NotImplementedError
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError
"""Get a node by its id."""
raise NotImplementedError
async def get_node(self, node_id: str) -> dict[str, str] | None:
raise NotImplementedError
"""Get an edge by its source and target node ids."""
raise NotImplementedError
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get all edges connected to a node."""
raise NotImplementedError
"""Get all edges connected to a node."""
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
raise NotImplementedError
"""Upsert a node into the graph."""
raise NotImplementedError
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
raise NotImplementedError
"""Upsert an edge into the graph."""
raise NotImplementedError
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
raise NotImplementedError
"""Delete a node from the graph."""
raise NotImplementedError
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
"""Embed nodes using an algorithm."""
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""Get all labels in the graph."""
raise NotImplementedError("Node embedding is not used in lightrag.")
"""Get all labels in the graph."""
async def get_all_labels(self) -> list[str]:
"""Get a knowledge graph of a node."""
raise NotImplementedError
"""Get a knowledge graph of a node."""
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
raise NotImplementedError

View File

@@ -5,7 +5,8 @@ import os
import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Union
import numpy as np
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
@@ -15,6 +16,7 @@ if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
from lightrag.types import KnowledgeGraph
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
@@ -396,7 +398,7 @@ class AGEStorage(BaseGraphStorage):
)
return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`) RETURN n
@@ -454,17 +456,7 @@ class AGEStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Find all edges between nodes of two given labels
Args:
source_node_label (str): Label of the source nodes
target_node_label (str): Label of the target nodes
Returns:
list: List of all relationships/edges found
"""
) -> dict[str, str] | None:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
@@ -488,7 +480,7 @@ class AGEStorage(BaseGraphStorage):
)
return result
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
@@ -526,7 +518,7 @@ class AGEStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the AGE database.
@@ -562,8 +554,8 @@ class AGEStorage(BaseGraphStorage):
retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels.
@@ -619,3 +611,15 @@ class AGEStorage(BaseGraphStorage):
yield connection
finally:
await self._driver.putconn(connection)
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -3,7 +3,9 @@ import inspect
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List
import numpy as np
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
@@ -15,6 +17,7 @@ from tenacity import (
wait_exponential,
)
from lightrag.types import KnowledgeGraph
from lightrag.utils import logger
from ..base import BaseGraphStorage
@@ -190,7 +193,7 @@ class GremlinStorage(BaseGraphStorage):
return result[0]["has_edge"]
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
@@ -252,17 +255,7 @@ class GremlinStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Find all edges between nodes of two given names
Args:
source_node_id (str): Name of the source nodes
target_node_id (str): Name of the target nodes
Returns:
dict|None: Dict of found edge properties, or None if not found
"""
) -> dict[str, str] | None:
entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f"""g
@@ -286,11 +279,7 @@ class GremlinStorage(BaseGraphStorage):
)
return edge_properties
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
"""
Retrieves all edges (relationships) for a particular node identified by its name.
:return: List of tuples containing edge sources and targets
"""
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
node_name = GremlinStorage._fix_name(source_node_id)
query = f"""g
.E()
@@ -316,7 +305,7 @@ class GremlinStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the Gremlin graph.
@@ -357,8 +346,8 @@ class GremlinStorage(BaseGraphStorage):
retry=retry_if_exception_type((GremlinServerError,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their names.
@@ -397,3 +386,17 @@ class GremlinStorage(BaseGraphStorage):
async def _node2vec_embed(self):
print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -12,7 +12,7 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"):
pm.install("motor")
from typing import Any, List, Tuple, Union
from typing import Any, List, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
@@ -448,7 +448,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""
Return the full node document (including "edges"), or None if missing.
"""
@@ -456,11 +456,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Return the first edge dict from source_node_id to target_node_id if it exists.
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
"""
) -> dict[str, str] | None:
pipeline = [
{"$match": {"_id": source_node_id}},
{
@@ -486,9 +482,7 @@ class MongoGraphStorage(BaseGraphStorage):
return e
return None
async def get_node_edges(
self, source_node_id: str
) -> Union[List[Tuple[str, str]], None]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Return a list of (source_id, target_id) for direct edges from source_node_id.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
@@ -522,7 +516,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def upsert_node(self, node_id: str, node_data: dict):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Insert or update a node document. If new, create an empty edges array.
"""
@@ -532,8 +526,8 @@ class MongoGraphStorage(BaseGraphStorage):
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict
):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
If an edge with the same target exists, we remove it and re-insert with updated data.
@@ -559,7 +553,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def delete_node(self, node_id: str):
async def delete_node(self, node_id: str) -> None:
"""
1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
@@ -576,7 +570,7 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Placeholder for demonstration, raises NotImplementedError.
"""
@@ -606,9 +600,7 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"])
return labels
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)

View File

@@ -3,7 +3,8 @@ import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
from typing import Any, List, Dict
import numpy as np
import pipmaster as pm
import configparser
@@ -191,7 +192,7 @@ class Neo4JStorage(BaseGraphStorage):
)
return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier.
Args:
@@ -252,17 +253,8 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""Find edge between two nodes identified by their labels.
) -> dict[str, str] | None:
Args:
source_node_id (str): Label of the source node
target_node_id (str): Label of the target node
Returns:
dict: Edge properties if found, with at least {"weight": 0.0}
None: If error occurs
"""
try:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
@@ -321,7 +313,7 @@ class Neo4JStorage(BaseGraphStorage):
# Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None}
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
node_label = source_node_id.strip('"')
"""
@@ -364,7 +356,7 @@ class Neo4JStorage(BaseGraphStorage):
)
),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the Neo4j database.
@@ -405,8 +397,8 @@ class Neo4JStorage(BaseGraphStorage):
),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels.
@@ -444,9 +436,7 @@ class Neo4JStorage(BaseGraphStorage):
async def _node2vec_embed(self):
print("Implemented but never called.")
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
@@ -603,7 +593,7 @@ class Neo4JStorage(BaseGraphStorage):
await traverse(label, 0)
return result
async def get_all_labels(self) -> List[str]:
async def get_all_labels(self) -> list[str]:
"""
Get all existing node labels in the database
Returns:
@@ -627,3 +617,11 @@ class Neo4JStorage(BaseGraphStorage):
async for record in result:
labels.append(record["label"])
return labels
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError

View File

@@ -51,11 +51,12 @@ Usage:
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
from typing import Any, cast
import networkx as nx
import numpy as np
from lightrag.types import KnowledgeGraph
from lightrag.utils import (
logger,
)
@@ -142,7 +143,7 @@ class NetworkXStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
async def index_done_callback(self) -> None:
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
@@ -151,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
@@ -162,35 +163,30 @@ class NetworkXStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str):
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
"""
async def delete_node(self, node_id: str) -> None:
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -226,3 +222,9 @@ class NetworkXStorage(BaseGraphStorage):
for source, target in edges:
if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -13,6 +13,7 @@ if not pm.is_installed("oracledb"):
pm.install("oracledb")
from lightrag.types import KnowledgeGraph
import oracledb
from ..base import (
@@ -378,9 +379,7 @@ class OracleGraphStorage(BaseGraphStorage):
#################### insert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""插入或更新节点"""
# print("go into upsert node method")
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
@@ -413,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
"""插入或更新边"""
# print("go into upsert edge method")
source_name = source_node_id
@@ -453,8 +452,7 @@ class OracleGraphStorage(BaseGraphStorage):
await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""为节点生成向量"""
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -471,7 +469,7 @@ class OracleGraphStorage(BaseGraphStorage):
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
async def index_done_callback(self):
async def index_done_callback(self) -> None:
"""写入graphhml图文件"""
logger.info(
"Node and edge data had been saved into oracle db already, so nothing to do here!"
@@ -493,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""根据源和目标节点id检查边是否存在"""
SQL = SQL_TEMPLATES["has_edge"]
params = {
"workspace": self.db.workspace,
@@ -510,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage):
return False
async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
SQL = SQL_TEMPLATES["node_degree"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
@@ -528,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge degree",degree)
return degree
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"]
params = {"workspace": self.db.workspace, "node_id": node_id}
@@ -544,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""根据源和目标节点id获取边"""
) -> dict[str, str] | None:
SQL = SQL_TEMPLATES["get_edge"]
params = {
"workspace": self.db.workspace,
@@ -560,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
return None
async def get_node_edges(self, source_node_id: str):
"""根据节点id获取节点的所有边"""
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"]
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
@@ -597,6 +591,14 @@ class OracleGraphStorage(BaseGraphStorage):
if res:
return res
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -4,11 +4,13 @@ import json
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Union
import numpy as np
import pipmaster as pm
from lightrag.types import KnowledgeGraph
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
@@ -835,7 +837,7 @@ class PGGraphStorage(BaseGraphStorage):
)
return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
@@ -890,17 +892,7 @@ class PGGraphStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Find all edges between nodes of two given labels
Args:
source_node_id (str): Label of the source nodes
target_node_id (str): Label of the target nodes
Returns:
list: List of all relationships/edges found
"""
) -> dict[str, str] | None:
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
@@ -924,7 +916,7 @@ class PGGraphStorage(BaseGraphStorage):
)
return result
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
@@ -972,14 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((PGGraphQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
Upsert a node in the AGE database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
properties = node_data
@@ -1010,8 +995,8 @@ class PGGraphStorage(BaseGraphStorage):
retry=retry_if_exception_type((PGGraphQueryException,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels.
@@ -1053,6 +1038,19 @@ class PGGraphStorage(BaseGraphStorage):
async def _node2vec_embed(self):
print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError
NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",

View File

@@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from lightrag.types import KnowledgeGraph
from sqlalchemy import create_engine, text
from tqdm import tqdm
@@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
@@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage):
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
source_name = source_node_id
target_name = target_node_id
weight = edge_data["weight"]
@@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage):
}
await self.db.execute(merge_sql, data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage):
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_edge"]
param = {
"source_name": source_node_id,
@@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage):
}
return await self.db.query(sql, param)
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True)
@@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage):
else:
return []
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",