diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 4ee88da2..f9e56032 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,4 +1,3 @@ -import asyncio import inspect import os import re @@ -29,7 +28,6 @@ from neo4j import ( # type: ignore exceptions as neo4jExceptions, AsyncDriver, AsyncManagedTransaction, - GraphDatabase, ) config = configparser.ConfigParser() @@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage): embedding_func=embedding_func, ) self._driver = None - self._driver_lock = asyncio.Lock() + def __post_init__(self): + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + async def initialize(self): URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) USERNAME = os.environ.get( "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) @@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage): ), ) DATABASE = os.environ.get( - "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) + "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace) ) self._driver: AsyncDriver = AsyncGraphDatabase.driver( @@ -98,71 +101,59 @@ class Neo4JStorage(BaseGraphStorage): max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, ) - # Try to connect to the database - with GraphDatabase.driver( - URI, - auth=(USERNAME, PASSWORD), - max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, - connection_timeout=CONNECTION_TIMEOUT, - connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, - ) as _sync_driver: - for database in (DATABASE, None): - self._DATABASE = database - connected = False + # Try to connect to the database and create it if it doesn't exist + for database in (DATABASE, None): + self._DATABASE = database + connected = False - try: - with _sync_driver.session(database=database) as session: - try: - session.run("MATCH (n) RETURN n LIMIT 0") - logger.info(f"Connected to {database} at {URI}") - connected = True - except neo4jExceptions.ServiceUnavailable as e: - logger.error( - f"{database} at {URI} is not available".capitalize() - ) - raise e - except neo4jExceptions.AuthError as e: - logger.error(f"Authentication failed for {database} at {URI}") - raise e - except neo4jExceptions.ClientError as e: - if e.code == "Neo.ClientError.Database.DatabaseNotFound": - logger.info( - f"{database} at {URI} not found. Try to create specified database.".capitalize() + try: + async with self._driver.session(database=database) as session: + try: + result = await session.run("MATCH (n) RETURN n LIMIT 0") + await result.consume() # Ensure result is consumed + logger.info(f"Connected to {database} at {URI}") + connected = True + except neo4jExceptions.ServiceUnavailable as e: + logger.error( + f"{database} at {URI} is not available".capitalize() ) - try: - with _sync_driver.session() as session: - session.run( - f"CREATE DATABASE `{database}` IF NOT EXISTS" + raise e + except neo4jExceptions.AuthError as e: + logger.error(f"Authentication failed for {database} at {URI}") + raise e + except neo4jExceptions.ClientError as e: + if e.code == "Neo.ClientError.Database.DatabaseNotFound": + logger.info( + f"{database} at {URI} not found. Try to create specified database.".capitalize() + ) + try: + async with self._driver.session() as session: + result = await session.run( + f"CREATE DATABASE `{database}` IF NOT EXISTS" + ) + await result.consume() # Ensure result is consumed + logger.info(f"{database} at {URI} created".capitalize()) + connected = True + except ( + neo4jExceptions.ClientError, + neo4jExceptions.DatabaseError, + ) as e: + if ( + e.code + == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" + ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"): + if database is not None: + logger.warning( + "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." ) - logger.info(f"{database} at {URI} created".capitalize()) - connected = True - except ( - neo4jExceptions.ClientError, - neo4jExceptions.DatabaseError, - ) as e: - if ( - e.code - == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" - ) or ( - e.code == "Neo.DatabaseError.Statement.ExecutionFailed" - ): - if database is not None: - logger.warning( - "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." - ) - if database is None: - logger.error(f"Failed to create {database} at {URI}") - raise e + if database is None: + logger.error(f"Failed to create {database} at {URI}") + raise e - if connected: - break + if connected: + break - def __post_init__(self): - self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, - } - - async def close(self): + async def finalize(self): """Close the Neo4j driver and release all resources""" if self._driver: await self._driver.close() @@ -170,7 +161,7 @@ class Neo4JStorage(BaseGraphStorage): async def __aexit__(self, exc_type, exc, tb): """Ensure driver is closed when context manager exits""" - await self.close() + await self.finalize() async def index_done_callback(self) -> None: # Noe4J handles persistence automatically