Implement the missing methods.

This commit is contained in:
zrguo
2025-03-04 15:50:53 +08:00
parent de9aeedad7
commit 3a2a636862
11 changed files with 1603 additions and 43 deletions

View File

@@ -7,7 +7,7 @@ from typing import Any, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import sys
from tenacity import (
@@ -512,11 +512,66 @@ class PGVectorStorage(BaseVectorStorage):
# PG handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs from the storage.
Args:
ids: List of vector IDs to be deleted
"""
if not ids:
return
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError
"""Delete an entity by its name from the vector storage.
Args:
entity_name: The name of the entity to delete
"""
try:
# Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute(
delete_sql,
{"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
"""Delete all relations associated with an entity.
Args:
entity_name: The name of the entity whose relations should be deleted
"""
try:
# Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute(
delete_sql,
{"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
@final
@@ -1086,20 +1141,192 @@ class PGGraphStorage(BaseGraphStorage):
print("Implemented but never called.")
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
"""
Delete a node from the graph.
Args:
node_id (str): The ID of the node to delete.
"""
label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, label)
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during node deletion: {%s}", e)
raise
async def remove_nodes(self, node_ids: list[str]) -> None:
"""
Remove multiple nodes from the graph.
Args:
node_ids (list[str]): A list of node IDs to remove.
"""
encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids]
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
WHERE n.node_id IN [%s]
DETACH DELETE n
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during node removal: {%s}", e)
raise
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
"""
Remove multiple edges from the graph.
Args:
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
"""
encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges]
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity)-[r]->(b:Entity)
WHERE [a.node_id, b.node_id] IN [%s]
DELETE r
$$) AS (r agtype)""" % (self.graph_name, edge_list)
try:
await self._query(query, readonly=False)
except Exception as e:
logger.error("Error during edge removal: {%s}", e)
raise
async def get_all_labels(self) -> list[str]:
"""
Get all labels (node IDs) in the graph.
Returns:
list[str]: A list of all labels in the graph.
"""
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
$$) AS (label text)""" % self.graph_name
results = await self._query(query)
labels = [self._decode_graph_label(result["label"]) for result in results]
return labels
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
"""
Generate node embeddings using the specified algorithm.
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
Args:
algorithm (str): The name of the embedding algorithm to use.
Returns:
tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs.
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
embed_func = self._node_embed_algorithms[algorithm]
return await embed_func()
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
"""
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
Args:
node_label (str): The label of the node to start from. If "*", the entire graph is returned.
max_depth (int): The maximum depth to traverse from the starting node.
Returns:
KnowledgeGraph: The retrieved subgraph.
"""
MAX_GRAPH_NODES = 1000
if node_label == "*":
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
OPTIONAL MATCH (n)-[r]->(m:Entity)
RETURN n, r, m
LIMIT %d
$$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES)
else:
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES)
results = await self._query(query)
nodes = set()
edges = []
for result in results:
if node_label == "*":
if result["n"]:
node = result["n"]
nodes.add(self._decode_graph_label(node["node_id"]))
if result["m"]:
node = result["m"]
nodes.add(self._decode_graph_label(node["node_id"]))
if result["r"]:
edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
else:
if result["nodes"]:
for node in result["nodes"]:
nodes.add(self._decode_graph_label(node["node_id"]))
if result["relationships"]:
for edge in result["relationships"]:
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
kg = KnowledgeGraph(
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
)
return kg
async def get_all_labels(self) -> list[str]:
"""
Get all node labels in the graph
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity)
RETURN DISTINCT n.node_id AS label
ORDER BY label
$$) AS (label agtype)""" % (self.graph_name)
try:
results = await self._query(query)
labels = []
for record in results:
if record["label"]:
labels.append(self._decode_graph_label(record["label"]))
return labels
except Exception as e:
logger.error(f"Error getting all labels: {str(e)}")
return []
async def drop(self) -> None:
"""Drop the storage"""