Fix linting

This commit is contained in:
yangdx
2025-03-11 10:28:25 +08:00
parent aefd596990
commit 7fddabb441

View File

@@ -200,9 +200,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
return single_result["node_exists"] return single_result["node_exists"]
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error checking node existence for {node_id}: {str(e)}")
f"Error checking node existence for {node_id}: {str(e)}"
)
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
@@ -229,7 +227,11 @@ class Neo4JStorage(BaseGraphStorage):
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
"RETURN COUNT(r) > 0 AS edgeExists" "RETURN COUNT(r) > 0 AS edgeExists"
) )
result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id) result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
single_result = await result.single() single_result = await result.single()
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
return single_result["edgeExists"] return single_result["edgeExists"]
@@ -274,7 +276,11 @@ class Neo4JStorage(BaseGraphStorage):
node_dict = dict(node) node_dict = dict(node)
# Remove base label from labels list if it exists # Remove base label from labels list if it exists
if "labels" in node_dict: if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] node_dict["labels"] = [
label
for label in node_dict["labels"]
if label != "base"
]
logger.debug(f"Neo4j query node {query} return: {node_dict}") logger.debug(f"Neo4j query node {query} return: {node_dict}")
return node_dict return node_dict
return None return None
@@ -313,20 +319,18 @@ class Neo4JStorage(BaseGraphStorage):
record = await result.single() record = await result.single()
if not record: if not record:
logger.warning( logger.warning(f"No node found with label '{node_id}'")
f"No node found with label '{node_id}'"
)
return 0 return 0
degree = record["degree"] degree = record["degree"]
logger.debug("Neo4j query node degree for {node_id} return: {degree}") logger.debug(
"Neo4j query node degree for {node_id} return: {degree}"
)
return degree return degree
finally: finally:
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error getting node degree for {node_id}: {str(e)}")
f"Error getting node degree for {node_id}: {str(e)}"
)
raise raise
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -373,7 +377,11 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
""" """
result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
try: try:
records = await result.fetch(2) records = await result.fetch(2)
@@ -471,10 +479,14 @@ class Neo4JStorage(BaseGraphStorage):
continue continue
source_label = ( source_label = (
source_node.get("entity_id") if source_node.get("entity_id") else None source_node.get("entity_id")
if source_node.get("entity_id")
else None
) )
target_label = ( target_label = (
connected_node.get("entity_id") if connected_node.get("entity_id") else None connected_node.get("entity_id")
if connected_node.get("entity_id")
else None
) )
if source_label and target_label: if source_label and target_label:
@@ -483,7 +495,9 @@ class Neo4JStorage(BaseGraphStorage):
await results.consume() # Ensure results are consumed await results.consume() # Ensure results are consumed
return edges return edges
except Exception as e: except Exception as e:
logger.error(f"Error getting edges for node {source_node_id}: {str(e)}") logger.error(
f"Error getting edges for node {source_node_id}: {str(e)}"
)
await results.consume() # Ensure results are consumed even on error await results.consume() # Ensure results are consumed even on error
raise raise
except Exception as e: except Exception as e:
@@ -520,11 +534,14 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
query = """ query = (
"""
MERGE (n:base {entity_id: $properties.entity_id}) MERGE (n:base {entity_id: $properties.entity_id})
SET n += $properties SET n += $properties
SET n:`%s` SET n:`%s`
""" % entity_type """
% entity_type
)
result = await tx.run(query, properties=properties) result = await tx.run(query, properties=properties)
logger.debug( logger.debug(
f"Upserted node with entity_id '{entity_id}' and properties: {properties}" f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
@@ -548,7 +565,6 @@ class Neo4JStorage(BaseGraphStorage):
) )
), ),
) )
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -728,7 +744,11 @@ class Neo4JStorage(BaseGraphStorage):
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( KnowledgeGraphNode(
id=f"{node_id}", id=f"{node_id}",
labels=[label for label in node.labels if label != "base"], labels=[
label
for label in node.labels
if label != "base"
],
properties=dict(node), properties=dict(node),
) )
) )
@@ -767,7 +787,9 @@ class Neo4JStorage(BaseGraphStorage):
logger.warning( logger.warning(
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching" "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
) )
return await self._robust_fallback(node_label, max_depth, min_degree) return await self._robust_fallback(
node_label, max_depth, min_degree
)
return result return result
@@ -843,7 +865,9 @@ class Neo4JStorage(BaseGraphStorage):
# Create KnowledgeGraphNode for target # Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode( target_node = KnowledgeGraphNode(
id=f"{target_id}", id=f"{target_id}",
labels=[label for label in b_node.labels if label != "base"], labels=[
label for label in b_node.labels if label != "base"
],
properties=dict(b_node.properties), properties=dict(b_node.properties),
) )
@@ -883,7 +907,9 @@ class Neo4JStorage(BaseGraphStorage):
# Create initial KnowledgeGraphNode # Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode( start_node = KnowledgeGraphNode(
id=f"{node_record['n'].get('entity_id')}", id=f"{node_record['n'].get('entity_id')}",
labels=[label for label in node_record["n"].labels if label != "base"], labels=[
label for label in node_record["n"].labels if label != "base"
],
properties=dict(node_record["n"].properties), properties=dict(node_record["n"].properties),
) )
finally: finally:
@@ -942,6 +968,7 @@ class Neo4JStorage(BaseGraphStorage):
Args: Args:
node_id: The label of the node to delete node_id: The label of the node to delete
""" """
async def _do_delete(tx: AsyncManagedTransaction): async def _do_delete(tx: AsyncManagedTransaction):
query = """ query = """
MATCH (n:base {entity_id: $entity_id}) MATCH (n:base {entity_id: $entity_id})
@@ -998,12 +1025,15 @@ class Neo4JStorage(BaseGraphStorage):
edges: List of edges to be deleted, each edge is a (source, target) tuple edges: List of edges to be deleted, each edge is a (source, target) tuple
""" """
for source, target in edges: for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction): async def _do_delete_edge(tx: AsyncManagedTransaction):
query = """ query = """
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
DELETE r DELETE r
""" """
result = await tx.run(query, source_entity_id=source, target_entity_id=target) result = await tx.run(
query, source_entity_id=source, target_entity_id=target
)
logger.debug(f"Deleted edge from '{source}' to '{target}'") logger.debug(f"Deleted edge from '{source}' to '{target}'")
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed