Convert _ensure_label method from async to sync

This commit is contained in:
yangdx
2025-03-08 10:23:27 +08:00
parent 78f8d7a1ce
commit af26d65698

View File

@@ -176,11 +176,17 @@ class Neo4JStorage(BaseGraphStorage):
# Noe4J handles persistence automatically
pass
async def _ensure_label(self, label: str) -> str:
def _ensure_label(self, label: str) -> str:
"""Ensure a label is valid
Args:
label: The label to validate
Returns:
str: The cleaned label
Raises:
ValueError: If label is empty after cleaning
"""
clean_label = label.strip('"')
if not clean_label:
@@ -201,7 +207,7 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If node_id is invalid
Exception: If there is an error executing the query
"""
entity_name_label = await self._ensure_label(node_id)
entity_name_label = self._ensure_label(node_id)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@@ -233,8 +239,8 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If either node_id is invalid
Exception: If there is an error executing the query
"""
entity_name_label_source = await self._ensure_label(source_node_id)
entity_name_label_target = await self._ensure_label(target_node_id)
entity_name_label_source = self._ensure_label(source_node_id)
entity_name_label_target = self._ensure_label(target_node_id)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
@@ -269,7 +275,7 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If node_id is invalid
Exception: If there is an error executing the query
"""
entity_name_label = await self._ensure_label(node_id)
entity_name_label = self._ensure_label(node_id)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
@@ -314,7 +320,7 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If node_id is invalid
Exception: If there is an error executing the query
"""
entity_name_label = await self._ensure_label(node_id)
entity_name_label = self._ensure_label(node_id)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
@@ -363,8 +369,8 @@ class Neo4JStorage(BaseGraphStorage):
Returns:
int: Sum of the degrees of both nodes
"""
entity_name_label_source = await self._ensure_label(src_id)
entity_name_label_target = await self._ensure_label(tgt_id)
entity_name_label_source = self._ensure_label(src_id)
entity_name_label_target = self._ensure_label(tgt_id)
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
@@ -393,8 +399,8 @@ class Neo4JStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
try:
entity_name_label_source = await self._ensure_label(source_node_id)
entity_name_label_target = await self._ensure_label(target_node_id)
entity_name_label_source = self._ensure_label(source_node_id)
entity_name_label_target = self._ensure_label(target_node_id)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
@@ -484,7 +490,7 @@ class Neo4JStorage(BaseGraphStorage):
Exception: If there is an error executing the query
"""
try:
node_label = await self._ensure_label(source_node_id)
node_label = self._ensure_label(source_node_id)
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
@@ -543,7 +549,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = await self._ensure_label(node_id)
label = self._ensure_label(node_id)
properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction):
@@ -591,8 +597,8 @@ class Neo4JStorage(BaseGraphStorage):
Raises:
ValueError: If either source or target node does not exist
"""
source_label = await self._ensure_label(source_node_id)
target_label = await self._ensure_label(target_node_id)
source_label = self._ensure_label(source_node_id)
target_label = self._ensure_label(target_node_id)
edge_properties = edge_data
# Check if both nodes exist
@@ -966,7 +972,7 @@ class Neo4JStorage(BaseGraphStorage):
Args:
node_id: The label of the node to delete
"""
label = await self._ensure_label(node_id)
label = self._ensure_label(node_id)
async def _do_delete(tx: AsyncManagedTransaction):
query = f"""
@@ -1024,8 +1030,8 @@ class Neo4JStorage(BaseGraphStorage):
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
source_label = await self._ensure_label(source)
target_label = await self._ensure_label(target)
source_label = self._ensure_label(source)
target_label = self._ensure_label(target)
async def _do_delete_edge(tx: AsyncManagedTransaction):
query = f"""