fix linting

This commit is contained in:
zrguo
2025-03-04 15:53:20 +08:00
parent 3a2a636862
commit 81568f3bad
11 changed files with 394 additions and 327 deletions

View File

@@ -567,62 +567,68 @@ class TiDBGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None:
"""Delete a node and all its related edges
Args:
node_id: The ID of the node to delete
"""
# First delete all edges related to this node
await self.db.execute(SQL_TEMPLATES["delete_node_edges"],
{"name": node_id, "workspace": self.db.workspace})
await self.db.execute(
SQL_TEMPLATES["delete_node_edges"],
{"name": node_id, "workspace": self.db.workspace},
)
# Then delete the node itself
await self.db.execute(SQL_TEMPLATES["delete_node"],
{"name": node_id, "workspace": self.db.workspace})
logger.debug(f"Node {node_id} and its related edges have been deleted from the graph")
await self.db.execute(
SQL_TEMPLATES["delete_node"],
{"name": node_id, "workspace": self.db.workspace},
)
logger.debug(
f"Node {node_id} and its related edges have been deleted from the graph"
)
async def get_all_labels(self) -> list[str]:
"""Get all entity types (labels) in the database
Returns:
List of labels sorted alphabetically
"""
result = await self.db.query(
SQL_TEMPLATES["get_all_labels"],
{"workspace": self.db.workspace},
multirows=True
SQL_TEMPLATES["get_all_labels"],
{"workspace": self.db.workspace},
multirows=True,
)
if not result:
return []
# Extract all labels
return [item["label"] for item in result]
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get a connected subgraph of nodes matching the specified label
Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000)
Args:
node_label: The node label to match
max_depth: Maximum depth of the subgraph
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Get matching nodes
if node_label == "*":
# Handle special case, get all nodes
node_results = await self.db.query(
SQL_TEMPLATES["get_all_nodes"],
{"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
multirows=True
multirows=True,
)
else:
# Get nodes matching the label
@@ -630,84 +636,93 @@ class TiDBGraphStorage(BaseGraphStorage):
node_results = await self.db.query(
SQL_TEMPLATES["get_matching_nodes"],
{"workspace": self.db.workspace, "label_pattern": label_pattern},
multirows=True
multirows=True,
)
if not node_results:
logger.warning(f"No nodes found matching label {node_label}")
return result
# Limit the number of returned nodes
if len(node_results) > MAX_GRAPH_NODES:
node_results = node_results[:MAX_GRAPH_NODES]
# Extract node names for edge query
node_names = [node["name"] for node in node_results]
node_names_str = ",".join([f"'{name}'" for name in node_names])
# Add nodes to result
for node in node_results:
node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]}
node_properties = {
k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
}
result.nodes.append(
KnowledgeGraphNode(
id=node["name"],
labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]],
properties=node_properties
labels=[node["entity_type"]]
if node.get("entity_type")
else [node["name"]],
properties=node_properties,
)
)
# Get related edges
edge_results = await self.db.query(
SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
{"workspace": self.db.workspace},
multirows=True
multirows=True,
)
if edge_results:
# Add edges to result
for edge in edge_results:
# Only include edges related to selected nodes
if edge["source_name"] in node_names and edge["target_name"] in node_names:
if (
edge["source_name"] in node_names
and edge["target_name"] in node_names
):
edge_id = f"{edge['source_name']}-{edge['target_name']}"
edge_properties = {k: v for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]}
edge_properties = {
k: v
for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]
}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="RELATED",
source=edge["source_name"],
target=edge["target_name"],
properties=edge_properties
properties=edge_properties,
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to delete
"""
for node_id in nodes:
await self.delete_node(node_id)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to delete, each edge is a (source, target) tuple
"""
for source, target in edges:
await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], {
"source": source,
"target": target,
"workspace": self.db.workspace
})
await self.db.execute(
SQL_TEMPLATES["remove_multiple_edges"],
{"source": source, "target": target, "workspace": self.db.workspace},
)
N_T = {
@@ -919,26 +934,26 @@ SQL_TEMPLATES = {
source_chunk_id = VALUES(source_chunk_id)
""",
"delete_node": """
DELETE FROM LIGHTRAG_GRAPH_NODES
DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE name = :name AND workspace = :workspace
""",
"delete_node_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
""",
"get_all_labels": """
SELECT DISTINCT entity_type as label
FROM LIGHTRAG_GRAPH_NODES
SELECT DISTINCT entity_type as label
FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY entity_type
""",
"get_matching_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE name LIKE :label_pattern AND workspace = :workspace
ORDER BY name
""",
"get_all_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY name
LIMIT :max_nodes
@@ -952,5 +967,5 @@ SQL_TEMPLATES = {
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace
"""
""",
}