fix event loop conflict

This commit is contained in:
Ken Wiltshire
2024-11-06 11:18:14 -05:00
parent a73ef6ec5a
commit ac0b7fd775
9 changed files with 185 additions and 152 deletions

View File

@@ -1,17 +1,16 @@
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast, Tuple, List, Dict
import numpy as np
from typing import Any, Union, Tuple, List, Dict
import inspect
from lightrag.utils import load_json, logger, write_json
from ..base import (
BaseGraphStorage
from lightrag.utils import logger
from ..base import BaseGraphStorage
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
)
from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
from contextlib import asynccontextmanager
from tenacity import (
@@ -26,7 +25,7 @@ from tenacity import (
class Neo4JStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print ("no preloading of graph with neo4j in production")
print("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config):
super().__init__(namespace=namespace, global_config=global_config)
@@ -35,7 +34,9 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None
def __post_init__(self):
@@ -43,59 +44,54 @@ class Neo4JStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def close(self):
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
await self._driver.close()
async def index_done_callback(self):
print ("KG successfully indexed.")
print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('\"')
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
result = await session.run(query)
async with self._driver.session() as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
)
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"')
async with self._driver.session() as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(query)
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
)
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
)
return single_result["edgeExists"]
def close(self):
self._driver.close()
def close(self):
self._driver.close()
async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session:
entity_name_label = node_id.strip('\"')
entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
record = await result.single()
@@ -103,54 +99,51 @@ class Neo4JStorage(BaseGraphStorage):
node = record["n"]
node_dict = dict(node)
logger.debug(
f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}'
f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
)
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('\"')
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session() as session:
query = f"""
MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
"""
result = await session.run(query)
record = await result.single()
result = await session.run(query)
record = await result.single()
if record:
edge_count = record["totalEdgeCount"]
edge_count = record["totalEdgeCount"]
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
)
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
)
return edge_count
else:
else:
return None
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('\"')
entity_name_label_target = tgt_id.strip('\"')
entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}'
)
f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
)
return degrees
async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('\"')
entity_name_label_target = target_node_id.strip('\"')
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
"""
Find all edges between nodes of two given labels
@@ -161,28 +154,30 @@ class Neo4JStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
async with self._driver.session() as session:
async with self._driver.session() as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
""".format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
result = await session.run(query)
""".format(
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
)
result = await session.run(query)
record = await result.single()
if record:
result = dict(record["edge_properties"])
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
)
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
else:
return None
async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]:
node_label = source_node_id.strip('\"')
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('"')
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
@@ -190,26 +185,37 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session() as session:
async with self._driver.session() as session:
results = await session.run(query)
edges = []
async for record in results:
source_node = record['n']
connected_node = record['connected']
source_label = list(source_node.labels)[0] if source_node.labels else None
target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
source_node = record["n"]
connected_node = record["connected"]
source_label = (
list(source_node.labels)[0] if source_node.labels else None
)
target_label = (
list(connected_node.labels)[0]
if connected_node and connected_node.labels
else None
)
if source_label and target_label:
edges.append((source_label, target_label))
return edges
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
@@ -219,7 +225,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('\"')
label = node_id.strip('"')
properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction):
@@ -228,7 +234,9 @@ class Neo4JStorage(BaseGraphStorage):
SET n += $properties
"""
await tx.run(query, properties=properties)
logger.debug(f"Upserted node with label '{label}' and properties: {properties}")
logger.debug(
f"Upserted node with label '{label}' and properties: {properties}"
)
try:
async with self._driver.session() as session:
@@ -236,13 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
)
),
)
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]):
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
"""
Upsert an edge and its properties between two nodes identified by their labels.
@@ -251,8 +267,8 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('\"')
target_node_label = target_node_id.strip('\"')
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction):
@@ -265,7 +281,9 @@ class Neo4JStorage(BaseGraphStorage):
RETURN r
"""
await tx.run(query, properties=edge_properties)
logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}")
logger.debug(
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
)
try:
async with self._driver.session() as session:
@@ -273,6 +291,6 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
raise
async def _node2vec_embed(self):
print ("Implemented but never called.")
print("Implemented but never called.")