Fix linting
This commit is contained in:
@@ -195,14 +195,12 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
) as session:
|
||||
try:
|
||||
query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
|
||||
result = await session.run(query, entity_id = node_id)
|
||||
result = await session.run(query, entity_id=node_id)
|
||||
single_result = await result.single()
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
return single_result["node_exists"]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking node existence for {node_id}: {str(e)}"
|
||||
)
|
||||
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
|
||||
await result.consume() # Ensure results are consumed even on error
|
||||
raise
|
||||
|
||||
@@ -229,7 +227,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
|
||||
"RETURN COUNT(r) > 0 AS edgeExists"
|
||||
)
|
||||
result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id)
|
||||
result = await session.run(
|
||||
query,
|
||||
source_entity_id=source_node_id,
|
||||
target_entity_id=target_node_id,
|
||||
)
|
||||
single_result = await result.single()
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
return single_result["edgeExists"]
|
||||
@@ -274,7 +276,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
node_dict = dict(node)
|
||||
# Remove base label from labels list if it exists
|
||||
if "labels" in node_dict:
|
||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
||||
node_dict["labels"] = [
|
||||
label
|
||||
for label in node_dict["labels"]
|
||||
if label != "base"
|
||||
]
|
||||
logger.debug(f"Neo4j query node {query} return: {node_dict}")
|
||||
return node_dict
|
||||
return None
|
||||
@@ -308,25 +314,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
OPTIONAL MATCH (n)-[r]-()
|
||||
RETURN COUNT(r) AS degree
|
||||
"""
|
||||
result = await session.run(query, entity_id = node_id)
|
||||
result = await session.run(query, entity_id=node_id)
|
||||
try:
|
||||
record = await result.single()
|
||||
|
||||
if not record:
|
||||
logger.warning(
|
||||
f"No node found with label '{node_id}'"
|
||||
)
|
||||
logger.warning(f"No node found with label '{node_id}'")
|
||||
return 0
|
||||
|
||||
degree = record["degree"]
|
||||
logger.debug("Neo4j query node degree for {node_id} return: {degree}")
|
||||
logger.debug(
|
||||
"Neo4j query node degree for {node_id} return: {degree}"
|
||||
)
|
||||
return degree
|
||||
finally:
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting node degree for {node_id}: {str(e)}"
|
||||
)
|
||||
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
@@ -373,7 +377,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
|
||||
RETURN properties(r) as edge_properties
|
||||
"""
|
||||
result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id)
|
||||
result = await session.run(
|
||||
query,
|
||||
source_entity_id=source_node_id,
|
||||
target_entity_id=target_node_id,
|
||||
)
|
||||
try:
|
||||
records = await result.fetch(2)
|
||||
|
||||
@@ -471,10 +479,14 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
continue
|
||||
|
||||
source_label = (
|
||||
source_node.get("entity_id") if source_node.get("entity_id") else None
|
||||
source_node.get("entity_id")
|
||||
if source_node.get("entity_id")
|
||||
else None
|
||||
)
|
||||
target_label = (
|
||||
connected_node.get("entity_id") if connected_node.get("entity_id") else None
|
||||
connected_node.get("entity_id")
|
||||
if connected_node.get("entity_id")
|
||||
else None
|
||||
)
|
||||
|
||||
if source_label and target_label:
|
||||
@@ -483,7 +495,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
await results.consume() # Ensure results are consumed
|
||||
return edges
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting edges for node {source_node_id}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error getting edges for node {source_node_id}: {str(e)}"
|
||||
)
|
||||
await results.consume() # Ensure results are consumed even on error
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -520,11 +534,14 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
query = """
|
||||
query = (
|
||||
"""
|
||||
MERGE (n:base {entity_id: $properties.entity_id})
|
||||
SET n += $properties
|
||||
SET n:`%s`
|
||||
""" % entity_type
|
||||
"""
|
||||
% entity_type
|
||||
)
|
||||
result = await tx.run(query, properties=properties)
|
||||
logger.debug(
|
||||
f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
|
||||
@@ -548,7 +565,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
@@ -728,7 +744,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=f"{node_id}",
|
||||
labels=[label for label in node.labels if label != "base"],
|
||||
labels=[
|
||||
label
|
||||
for label in node.labels
|
||||
if label != "base"
|
||||
],
|
||||
properties=dict(node),
|
||||
)
|
||||
)
|
||||
@@ -767,7 +787,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.warning(
|
||||
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
|
||||
)
|
||||
return await self._robust_fallback(node_label, max_depth, min_degree)
|
||||
return await self._robust_fallback(
|
||||
node_label, max_depth, min_degree
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -843,7 +865,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
# Create KnowledgeGraphNode for target
|
||||
target_node = KnowledgeGraphNode(
|
||||
id=f"{target_id}",
|
||||
labels=[label for label in b_node.labels if label != "base"],
|
||||
labels=[
|
||||
label for label in b_node.labels if label != "base"
|
||||
],
|
||||
properties=dict(b_node.properties),
|
||||
)
|
||||
|
||||
@@ -883,7 +907,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
# Create initial KnowledgeGraphNode
|
||||
start_node = KnowledgeGraphNode(
|
||||
id=f"{node_record['n'].get('entity_id')}",
|
||||
labels=[label for label in node_record["n"].labels if label != "base"],
|
||||
labels=[
|
||||
label for label in node_record["n"].labels if label != "base"
|
||||
],
|
||||
properties=dict(node_record["n"].properties),
|
||||
)
|
||||
finally:
|
||||
@@ -942,6 +968,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
Args:
|
||||
node_id: The label of the node to delete
|
||||
"""
|
||||
|
||||
async def _do_delete(tx: AsyncManagedTransaction):
|
||||
query = """
|
||||
MATCH (n:base {entity_id: $entity_id})
|
||||
@@ -998,12 +1025,15 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
||||
"""
|
||||
for source, target in edges:
|
||||
|
||||
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
||||
query = """
|
||||
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
|
||||
DELETE r
|
||||
"""
|
||||
result = await tx.run(query, source_entity_id=source, target_entity_id=target)
|
||||
result = await tx.run(
|
||||
query, source_entity_id=source, target_entity_id=target
|
||||
)
|
||||
logger.debug(f"Deleted edge from '{source}' to '{target}'")
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
|
Reference in New Issue
Block a user