fix linting
This commit is contained in:
@@ -567,62 +567,68 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Delete a node and all its related edges
|
||||
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to delete
|
||||
"""
|
||||
# First delete all edges related to this node
|
||||
await self.db.execute(SQL_TEMPLATES["delete_node_edges"],
|
||||
{"name": node_id, "workspace": self.db.workspace})
|
||||
|
||||
await self.db.execute(
|
||||
SQL_TEMPLATES["delete_node_edges"],
|
||||
{"name": node_id, "workspace": self.db.workspace},
|
||||
)
|
||||
|
||||
# Then delete the node itself
|
||||
await self.db.execute(SQL_TEMPLATES["delete_node"],
|
||||
{"name": node_id, "workspace": self.db.workspace})
|
||||
|
||||
logger.debug(f"Node {node_id} and its related edges have been deleted from the graph")
|
||||
|
||||
await self.db.execute(
|
||||
SQL_TEMPLATES["delete_node"],
|
||||
{"name": node_id, "workspace": self.db.workspace},
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Node {node_id} and its related edges have been deleted from the graph"
|
||||
)
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""Get all entity types (labels) in the database
|
||||
|
||||
|
||||
Returns:
|
||||
List of labels sorted alphabetically
|
||||
"""
|
||||
result = await self.db.query(
|
||||
SQL_TEMPLATES["get_all_labels"],
|
||||
{"workspace": self.db.workspace},
|
||||
multirows=True
|
||||
SQL_TEMPLATES["get_all_labels"],
|
||||
{"workspace": self.db.workspace},
|
||||
multirows=True,
|
||||
)
|
||||
|
||||
|
||||
if not result:
|
||||
return []
|
||||
|
||||
|
||||
# Extract all labels
|
||||
return [item["label"] for item in result]
|
||||
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Get a connected subgraph of nodes matching the specified label
|
||||
Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000)
|
||||
|
||||
|
||||
Args:
|
||||
node_label: The node label to match
|
||||
max_depth: Maximum depth of the subgraph
|
||||
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges
|
||||
"""
|
||||
result = KnowledgeGraph()
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
|
||||
|
||||
# Get matching nodes
|
||||
if node_label == "*":
|
||||
# Handle special case, get all nodes
|
||||
node_results = await self.db.query(
|
||||
SQL_TEMPLATES["get_all_nodes"],
|
||||
{"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
|
||||
multirows=True
|
||||
multirows=True,
|
||||
)
|
||||
else:
|
||||
# Get nodes matching the label
|
||||
@@ -630,84 +636,93 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
node_results = await self.db.query(
|
||||
SQL_TEMPLATES["get_matching_nodes"],
|
||||
{"workspace": self.db.workspace, "label_pattern": label_pattern},
|
||||
multirows=True
|
||||
multirows=True,
|
||||
)
|
||||
|
||||
|
||||
if not node_results:
|
||||
logger.warning(f"No nodes found matching label {node_label}")
|
||||
return result
|
||||
|
||||
|
||||
# Limit the number of returned nodes
|
||||
if len(node_results) > MAX_GRAPH_NODES:
|
||||
node_results = node_results[:MAX_GRAPH_NODES]
|
||||
|
||||
|
||||
# Extract node names for edge query
|
||||
node_names = [node["name"] for node in node_results]
|
||||
node_names_str = ",".join([f"'{name}'" for name in node_names])
|
||||
|
||||
|
||||
# Add nodes to result
|
||||
for node in node_results:
|
||||
node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]}
|
||||
node_properties = {
|
||||
k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
|
||||
}
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node["name"],
|
||||
labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]],
|
||||
properties=node_properties
|
||||
labels=[node["entity_type"]]
|
||||
if node.get("entity_type")
|
||||
else [node["name"]],
|
||||
properties=node_properties,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Get related edges
|
||||
edge_results = await self.db.query(
|
||||
SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
|
||||
{"workspace": self.db.workspace},
|
||||
multirows=True
|
||||
multirows=True,
|
||||
)
|
||||
|
||||
|
||||
if edge_results:
|
||||
# Add edges to result
|
||||
for edge in edge_results:
|
||||
# Only include edges related to selected nodes
|
||||
if edge["source_name"] in node_names and edge["target_name"] in node_names:
|
||||
if (
|
||||
edge["source_name"] in node_names
|
||||
and edge["target_name"] in node_names
|
||||
):
|
||||
edge_id = f"{edge['source_name']}-{edge['target_name']}"
|
||||
edge_properties = {k: v for k, v in edge.items()
|
||||
if k not in ["id", "source_name", "target_name"]}
|
||||
|
||||
edge_properties = {
|
||||
k: v
|
||||
for k, v in edge.items()
|
||||
if k not in ["id", "source_name", "target_name"]
|
||||
}
|
||||
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type="RELATED",
|
||||
source=edge["source_name"],
|
||||
target=edge["target_name"],
|
||||
properties=edge_properties
|
||||
properties=edge_properties,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def remove_nodes(self, nodes: list[str]):
|
||||
"""Delete multiple nodes
|
||||
|
||||
|
||||
Args:
|
||||
nodes: List of node IDs to delete
|
||||
"""
|
||||
for node_id in nodes:
|
||||
await self.delete_node(node_id)
|
||||
|
||||
|
||||
async def remove_edges(self, edges: list[tuple[str, str]]):
|
||||
"""Delete multiple edges
|
||||
|
||||
|
||||
Args:
|
||||
edges: List of edges to delete, each edge is a (source, target) tuple
|
||||
"""
|
||||
for source, target in edges:
|
||||
await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], {
|
||||
"source": source,
|
||||
"target": target,
|
||||
"workspace": self.db.workspace
|
||||
})
|
||||
await self.db.execute(
|
||||
SQL_TEMPLATES["remove_multiple_edges"],
|
||||
{"source": source, "target": target, "workspace": self.db.workspace},
|
||||
)
|
||||
|
||||
|
||||
N_T = {
|
||||
@@ -919,26 +934,26 @@ SQL_TEMPLATES = {
|
||||
source_chunk_id = VALUES(source_chunk_id)
|
||||
""",
|
||||
"delete_node": """
|
||||
DELETE FROM LIGHTRAG_GRAPH_NODES
|
||||
DELETE FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE name = :name AND workspace = :workspace
|
||||
""",
|
||||
"delete_node_edges": """
|
||||
DELETE FROM LIGHTRAG_GRAPH_EDGES
|
||||
DELETE FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
|
||||
""",
|
||||
"get_all_labels": """
|
||||
SELECT DISTINCT entity_type as label
|
||||
FROM LIGHTRAG_GRAPH_NODES
|
||||
SELECT DISTINCT entity_type as label
|
||||
FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE workspace = :workspace
|
||||
ORDER BY entity_type
|
||||
""",
|
||||
"get_matching_nodes": """
|
||||
SELECT * FROM LIGHTRAG_GRAPH_NODES
|
||||
SELECT * FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE name LIKE :label_pattern AND workspace = :workspace
|
||||
ORDER BY name
|
||||
""",
|
||||
"get_all_nodes": """
|
||||
SELECT * FROM LIGHTRAG_GRAPH_NODES
|
||||
SELECT * FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE workspace = :workspace
|
||||
ORDER BY name
|
||||
LIMIT :max_nodes
|
||||
@@ -952,5 +967,5 @@ SQL_TEMPLATES = {
|
||||
DELETE FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE (source_name = :source AND target_name = :target)
|
||||
AND workspace = :workspace
|
||||
"""
|
||||
""",
|
||||
}
|
||||
|
Reference in New Issue
Block a user