Fix linting

This commit is contained in:
yangdx
2025-03-11 10:28:25 +08:00
parent aefd596990
commit 7fddabb441

View File

@@ -195,14 +195,12 @@ class Neo4JStorage(BaseGraphStorage):
) as session:
try:
query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id = node_id)
result = await session.run(query, entity_id=node_id)
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return single_result["node_exists"]
except Exception as e:
logger.error(
f"Error checking node existence for {node_id}: {str(e)}"
)
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
await result.consume() # Ensure results are consumed even on error
raise
@@ -229,7 +227,11 @@ class Neo4JStorage(BaseGraphStorage):
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id)
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return single_result["edgeExists"]
@@ -274,7 +276,11 @@ class Neo4JStorage(BaseGraphStorage):
node_dict = dict(node)
# Remove base label from labels list if it exists
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
node_dict["labels"] = [
label
for label in node_dict["labels"]
if label != "base"
]
logger.debug(f"Neo4j query node {query} return: {node_dict}")
return node_dict
return None
@@ -308,25 +314,23 @@ class Neo4JStorage(BaseGraphStorage):
OPTIONAL MATCH (n)-[r]-()
RETURN COUNT(r) AS degree
"""
result = await session.run(query, entity_id = node_id)
result = await session.run(query, entity_id=node_id)
try:
record = await result.single()
if not record:
logger.warning(
f"No node found with label '{node_id}'"
)
logger.warning(f"No node found with label '{node_id}'")
return 0
degree = record["degree"]
logger.debug("Neo4j query node degree for {node_id} return: {degree}")
logger.debug(
"Neo4j query node degree for {node_id} return: {degree}"
)
return degree
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(
f"Error getting node degree for {node_id}: {str(e)}"
)
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
raise
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -373,7 +377,11 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
RETURN properties(r) as edge_properties
"""
result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id)
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
try:
records = await result.fetch(2)
@@ -471,10 +479,14 @@ class Neo4JStorage(BaseGraphStorage):
continue
source_label = (
source_node.get("entity_id") if source_node.get("entity_id") else None
source_node.get("entity_id")
if source_node.get("entity_id")
else None
)
target_label = (
connected_node.get("entity_id") if connected_node.get("entity_id") else None
connected_node.get("entity_id")
if connected_node.get("entity_id")
else None
)
if source_label and target_label:
@@ -483,7 +495,9 @@ class Neo4JStorage(BaseGraphStorage):
await results.consume() # Ensure results are consumed
return edges
except Exception as e:
logger.error(f"Error getting edges for node {source_node_id}: {str(e)}")
logger.error(
f"Error getting edges for node {source_node_id}: {str(e)}"
)
await results.consume() # Ensure results are consumed even on error
raise
except Exception as e:
@@ -520,11 +534,14 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
query = """
query = (
"""
MERGE (n:base {entity_id: $properties.entity_id})
SET n += $properties
SET n:`%s`
""" % entity_type
"""
% entity_type
)
result = await tx.run(query, properties=properties)
logger.debug(
f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
@@ -548,7 +565,6 @@ class Neo4JStorage(BaseGraphStorage):
)
),
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -728,7 +744,11 @@ class Neo4JStorage(BaseGraphStorage):
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[label for label in node.labels if label != "base"],
labels=[
label
for label in node.labels
if label != "base"
],
properties=dict(node),
)
)
@@ -767,7 +787,9 @@ class Neo4JStorage(BaseGraphStorage):
logger.warning(
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
)
return await self._robust_fallback(node_label, max_depth, min_degree)
return await self._robust_fallback(
node_label, max_depth, min_degree
)
return result
@@ -843,7 +865,9 @@ class Neo4JStorage(BaseGraphStorage):
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=[label for label in b_node.labels if label != "base"],
labels=[
label for label in b_node.labels if label != "base"
],
properties=dict(b_node.properties),
)
@@ -883,7 +907,9 @@ class Neo4JStorage(BaseGraphStorage):
# Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode(
id=f"{node_record['n'].get('entity_id')}",
labels=[label for label in node_record["n"].labels if label != "base"],
labels=[
label for label in node_record["n"].labels if label != "base"
],
properties=dict(node_record["n"].properties),
)
finally:
@@ -942,6 +968,7 @@ class Neo4JStorage(BaseGraphStorage):
Args:
node_id: The label of the node to delete
"""
async def _do_delete(tx: AsyncManagedTransaction):
query = """
MATCH (n:base {entity_id: $entity_id})
@@ -998,12 +1025,15 @@ class Neo4JStorage(BaseGraphStorage):
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction):
query = """
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
DELETE r
"""
result = await tx.run(query, source_entity_id=source, target_entity_id=target)
result = await tx.run(
query, source_entity_id=source, target_entity_id=target
)
logger.debug(f"Deleted edge from '{source}' to '{target}'")
await result.consume() # Ensure result is fully consumed