Implement the missing methods.

This commit is contained in:
zrguo
2025-03-04 15:50:53 +08:00
parent de9aeedad7
commit 3a2a636862
11 changed files with 1603 additions and 43 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, Union, final
import numpy as np
from lightrag.types import KnowledgeGraph
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
@@ -566,15 +566,148 @@ class TiDBGraphStorage(BaseGraphStorage):
pass
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
"""Delete a node and all its related edges
Args:
node_id: The ID of the node to delete
"""
# First delete all edges related to this node
await self.db.execute(SQL_TEMPLATES["delete_node_edges"],
{"name": node_id, "workspace": self.db.workspace})
# Then delete the node itself
await self.db.execute(SQL_TEMPLATES["delete_node"],
{"name": node_id, "workspace": self.db.workspace})
logger.debug(f"Node {node_id} and its related edges have been deleted from the graph")
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
"""Get all entity types (labels) in the database
Returns:
List of labels sorted alphabetically
"""
result = await self.db.query(
SQL_TEMPLATES["get_all_labels"],
{"workspace": self.db.workspace},
multirows=True
)
if not result:
return []
# Extract all labels
return [item["label"] for item in result]
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
"""
Get a connected subgraph of nodes matching the specified label
Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000)
Args:
node_label: The node label to match
max_depth: Maximum depth of the subgraph
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Get matching nodes
if node_label == "*":
# Handle special case, get all nodes
node_results = await self.db.query(
SQL_TEMPLATES["get_all_nodes"],
{"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
multirows=True
)
else:
# Get nodes matching the label
label_pattern = f"%{node_label}%"
node_results = await self.db.query(
SQL_TEMPLATES["get_matching_nodes"],
{"workspace": self.db.workspace, "label_pattern": label_pattern},
multirows=True
)
if not node_results:
logger.warning(f"No nodes found matching label {node_label}")
return result
# Limit the number of returned nodes
if len(node_results) > MAX_GRAPH_NODES:
node_results = node_results[:MAX_GRAPH_NODES]
# Extract node names for edge query
node_names = [node["name"] for node in node_results]
node_names_str = ",".join([f"'{name}'" for name in node_names])
# Add nodes to result
for node in node_results:
node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]}
result.nodes.append(
KnowledgeGraphNode(
id=node["name"],
labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]],
properties=node_properties
)
)
# Get related edges
edge_results = await self.db.query(
SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
{"workspace": self.db.workspace},
multirows=True
)
if edge_results:
# Add edges to result
for edge in edge_results:
# Only include edges related to selected nodes
if edge["source_name"] in node_names and edge["target_name"] in node_names:
edge_id = f"{edge['source_name']}-{edge['target_name']}"
edge_properties = {k: v for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="RELATED",
source=edge["source_name"],
target=edge["target_name"],
properties=edge_properties
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to delete
"""
for node_id in nodes:
await self.delete_node(node_id)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to delete, each edge is a (source, target) tuple
"""
for source, target in edges:
await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], {
"source": source,
"target": target,
"workspace": self.db.workspace
})
N_T = {
@@ -785,4 +918,39 @@ SQL_TEMPLATES = {
weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
source_chunk_id = VALUES(source_chunk_id)
""",
"delete_node": """
DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE name = :name AND workspace = :workspace
""",
"delete_node_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
""",
"get_all_labels": """
SELECT DISTINCT entity_type as label
FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY entity_type
""",
"get_matching_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE name LIKE :label_pattern AND workspace = :workspace
ORDER BY name
""",
"get_all_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY name
LIMIT :max_nodes
""",
"get_related_edges": """
SELECT * FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name IN (:node_names) OR target_name IN (:node_names))
AND workspace = :workspace
""",
"remove_multiple_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace
"""
}