Implement the missing methods.
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Any, Union, final
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
|
||||
import sys
|
||||
from tenacity import (
|
||||
@@ -512,11 +512,66 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
# PG handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete vectors with specified IDs from the storage.
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
||||
return
|
||||
|
||||
ids_list = ",".join([f"'{id}'" for id in ids])
|
||||
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
|
||||
|
||||
try:
|
||||
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
|
||||
logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
raise NotImplementedError
|
||||
"""Delete an entity by its name from the vector storage.
|
||||
|
||||
Args:
|
||||
entity_name: The name of the entity to delete
|
||||
"""
|
||||
try:
|
||||
# Construct SQL to delete the entity
|
||||
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
|
||||
WHERE workspace=$1 AND entity_name=$2"""
|
||||
|
||||
await self.db.execute(
|
||||
delete_sql,
|
||||
{"workspace": self.db.workspace, "entity_name": entity_name}
|
||||
)
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
raise NotImplementedError
|
||||
"""Delete all relations associated with an entity.
|
||||
|
||||
Args:
|
||||
entity_name: The name of the entity whose relations should be deleted
|
||||
"""
|
||||
try:
|
||||
# Delete relations where the entity is either the source or target
|
||||
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
|
||||
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
|
||||
|
||||
await self.db.execute(
|
||||
delete_sql,
|
||||
{"workspace": self.db.workspace, "entity_name": entity_name}
|
||||
)
|
||||
logger.debug(f"Successfully deleted relations for entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
||||
|
||||
|
||||
@final
|
||||
@@ -1086,20 +1141,192 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
print("Implemented but never called.")
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
"""
|
||||
Delete a node from the graph.
|
||||
|
||||
Args:
|
||||
node_id (str): The ID of the node to delete.
|
||||
"""
|
||||
label = self._encode_graph_label(node_id.strip('"'))
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
DETACH DELETE n
|
||||
$$) AS (n agtype)""" % (self.graph_name, label)
|
||||
|
||||
try:
|
||||
await self._query(query, readonly=False)
|
||||
except Exception as e:
|
||||
logger.error("Error during node deletion: {%s}", e)
|
||||
raise
|
||||
|
||||
async def remove_nodes(self, node_ids: list[str]) -> None:
|
||||
"""
|
||||
Remove multiple nodes from the graph.
|
||||
|
||||
Args:
|
||||
node_ids (list[str]): A list of node IDs to remove.
|
||||
"""
|
||||
encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids]
|
||||
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity)
|
||||
WHERE n.node_id IN [%s]
|
||||
DETACH DELETE n
|
||||
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
||||
|
||||
try:
|
||||
await self._query(query, readonly=False)
|
||||
except Exception as e:
|
||||
logger.error("Error during node removal: {%s}", e)
|
||||
raise
|
||||
|
||||
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
|
||||
"""
|
||||
Remove multiple edges from the graph.
|
||||
|
||||
Args:
|
||||
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
||||
"""
|
||||
encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges]
|
||||
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:Entity)-[r]->(b:Entity)
|
||||
WHERE [a.node_id, b.node_id] IN [%s]
|
||||
DELETE r
|
||||
$$) AS (r agtype)""" % (self.graph_name, edge_list)
|
||||
|
||||
try:
|
||||
await self._query(query, readonly=False)
|
||||
except Exception as e:
|
||||
logger.error("Error during edge removal: {%s}", e)
|
||||
raise
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
Get all labels (node IDs) in the graph.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of all labels in the graph.
|
||||
"""
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity)
|
||||
RETURN DISTINCT n.node_id AS label
|
||||
$$) AS (label text)""" % self.graph_name
|
||||
|
||||
results = await self._query(query)
|
||||
labels = [self._decode_graph_label(result["label"]) for result in results]
|
||||
|
||||
return labels
|
||||
|
||||
async def embed_nodes(
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
"""
|
||||
Generate node embeddings using the specified algorithm.
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
Args:
|
||||
algorithm (str): The name of the embedding algorithm to use.
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs.
|
||||
"""
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
|
||||
|
||||
embed_func = self._node_embed_algorithms[algorithm]
|
||||
return await embed_func()
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
||||
"""
|
||||
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
|
||||
|
||||
Args:
|
||||
node_label (str): The label of the node to start from. If "*", the entire graph is returned.
|
||||
max_depth (int): The maximum depth to traverse from the starting node.
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: The retrieved subgraph.
|
||||
"""
|
||||
MAX_GRAPH_NODES = 1000
|
||||
|
||||
if node_label == "*":
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity)
|
||||
OPTIONAL MATCH (n)-[r]->(m:Entity)
|
||||
RETURN n, r, m
|
||||
LIMIT %d
|
||||
$$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES)
|
||||
else:
|
||||
encoded_node_label = self._encode_graph_label(node_label.strip('"'))
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
OPTIONAL MATCH p = (n)-[*..%d]-(m)
|
||||
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
||||
LIMIT %d
|
||||
$$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES)
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
nodes = set()
|
||||
edges = []
|
||||
|
||||
for result in results:
|
||||
if node_label == "*":
|
||||
if result["n"]:
|
||||
node = result["n"]
|
||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
||||
if result["m"]:
|
||||
node = result["m"]
|
||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
||||
if result["r"]:
|
||||
edge = result["r"]
|
||||
src_id = self._decode_graph_label(edge["start_id"])
|
||||
tgt_id = self._decode_graph_label(edge["end_id"])
|
||||
edges.append((src_id, tgt_id))
|
||||
else:
|
||||
if result["nodes"]:
|
||||
for node in result["nodes"]:
|
||||
nodes.add(self._decode_graph_label(node["node_id"]))
|
||||
if result["relationships"]:
|
||||
for edge in result["relationships"]:
|
||||
src_id = self._decode_graph_label(edge["start_id"])
|
||||
tgt_id = self._decode_graph_label(edge["end_id"])
|
||||
edges.append((src_id, tgt_id))
|
||||
|
||||
kg = KnowledgeGraph(
|
||||
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
|
||||
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
|
||||
)
|
||||
|
||||
return kg
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
Get all node labels in the graph
|
||||
Returns:
|
||||
[label1, label2, ...] # Alphabetically sorted label list
|
||||
"""
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity)
|
||||
RETURN DISTINCT n.node_id AS label
|
||||
ORDER BY label
|
||||
$$) AS (label agtype)""" % (self.graph_name)
|
||||
|
||||
try:
|
||||
results = await self._query(query)
|
||||
labels = []
|
||||
for record in results:
|
||||
if record["label"]:
|
||||
labels.append(self._decode_graph_label(record["label"]))
|
||||
return labels
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all labels: {str(e)}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
|
Reference in New Issue
Block a user