Convert _ensure_label method from async to sync

This commit is contained in:
yangdx
2025-03-08 10:23:27 +08:00
parent 78f8d7a1ce
commit af26d65698

View File

@@ -176,11 +176,17 @@ class Neo4JStorage(BaseGraphStorage):
# Noe4J handles persistence automatically # Noe4J handles persistence automatically
pass pass
async def _ensure_label(self, label: str) -> str: def _ensure_label(self, label: str) -> str:
"""Ensure a label is valid """Ensure a label is valid
Args: Args:
label: The label to validate label: The label to validate
Returns:
str: The cleaned label
Raises:
ValueError: If label is empty after cleaning
""" """
clean_label = label.strip('"') clean_label = label.strip('"')
if not clean_label: if not clean_label:
@@ -201,7 +207,7 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If node_id is invalid ValueError: If node_id is invalid
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
entity_name_label = await self._ensure_label(node_id) entity_name_label = self._ensure_label(node_id)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
@@ -233,8 +239,8 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If either node_id is invalid ValueError: If either node_id is invalid
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
entity_name_label_source = await self._ensure_label(source_node_id) entity_name_label_source = self._ensure_label(source_node_id)
entity_name_label_target = await self._ensure_label(target_node_id) entity_name_label_target = self._ensure_label(target_node_id)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
@@ -269,7 +275,7 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If node_id is invalid ValueError: If node_id is invalid
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
entity_name_label = await self._ensure_label(node_id) entity_name_label = self._ensure_label(node_id)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
@@ -314,7 +320,7 @@ class Neo4JStorage(BaseGraphStorage):
ValueError: If node_id is invalid ValueError: If node_id is invalid
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
entity_name_label = await self._ensure_label(node_id) entity_name_label = self._ensure_label(node_id)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
@@ -363,8 +369,8 @@ class Neo4JStorage(BaseGraphStorage):
Returns: Returns:
int: Sum of the degrees of both nodes int: Sum of the degrees of both nodes
""" """
entity_name_label_source = await self._ensure_label(src_id) entity_name_label_source = self._ensure_label(src_id)
entity_name_label_target = await self._ensure_label(tgt_id) entity_name_label_target = self._ensure_label(tgt_id)
src_degree = await self.node_degree(entity_name_label_source) src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target) trg_degree = await self.node_degree(entity_name_label_target)
@@ -393,8 +399,8 @@ class Neo4JStorage(BaseGraphStorage):
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
try: try:
entity_name_label_source = await self._ensure_label(source_node_id) entity_name_label_source = self._ensure_label(source_node_id)
entity_name_label_target = await self._ensure_label(target_node_id) entity_name_label_target = self._ensure_label(target_node_id)
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
@@ -484,7 +490,7 @@ class Neo4JStorage(BaseGraphStorage):
Exception: If there is an error executing the query Exception: If there is an error executing the query
""" """
try: try:
node_label = await self._ensure_label(source_node_id) node_label = self._ensure_label(source_node_id)
query = f"""MATCH (n:`{node_label}`) query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[r]-(connected)
@@ -543,7 +549,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
label = await self._ensure_label(node_id) label = self._ensure_label(node_id)
properties = node_data properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction): async def _do_upsert(tx: AsyncManagedTransaction):
@@ -591,8 +597,8 @@ class Neo4JStorage(BaseGraphStorage):
Raises: Raises:
ValueError: If either source or target node does not exist ValueError: If either source or target node does not exist
""" """
source_label = await self._ensure_label(source_node_id) source_label = self._ensure_label(source_node_id)
target_label = await self._ensure_label(target_node_id) target_label = self._ensure_label(target_node_id)
edge_properties = edge_data edge_properties = edge_data
# Check if both nodes exist # Check if both nodes exist
@@ -966,7 +972,7 @@ class Neo4JStorage(BaseGraphStorage):
Args: Args:
node_id: The label of the node to delete node_id: The label of the node to delete
""" """
label = await self._ensure_label(node_id) label = self._ensure_label(node_id)
async def _do_delete(tx: AsyncManagedTransaction): async def _do_delete(tx: AsyncManagedTransaction):
query = f""" query = f"""
@@ -1024,8 +1030,8 @@ class Neo4JStorage(BaseGraphStorage):
edges: List of edges to be deleted, each edge is a (source, target) tuple edges: List of edges to be deleted, each edge is a (source, target) tuple
""" """
for source, target in edges: for source, target in edges:
source_label = await self._ensure_label(source) source_label = self._ensure_label(source)
target_label = await self._ensure_label(target) target_label = self._ensure_label(target)
async def _do_delete_edge(tx: AsyncManagedTransaction): async def _do_delete_edge(tx: AsyncManagedTransaction):
query = f""" query = f"""