Refactor Neo4J storage initialization and cleanup

- Make initialization async
- Rename close() to finalize()
This commit is contained in:
yangdx
2025-04-02 10:45:21 +08:00
parent ab9d210fcd
commit 5f678adb71

View File

@@ -1,4 +1,3 @@
import asyncio
import inspect import inspect
import os import os
import re import re
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
exceptions as neo4jExceptions, exceptions as neo4jExceptions,
AsyncDriver, AsyncDriver,
AsyncManagedTransaction, AsyncManagedTransaction,
GraphDatabase,
) )
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self._driver = None self._driver = None
self._driver_lock = asyncio.Lock()
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
USERNAME = os.environ.get( USERNAME = os.environ.get(
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
), ),
) )
DATABASE = os.environ.get( DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
) )
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,59 @@ class Neo4JStorage(BaseGraphStorage):
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
) )
# Try to connect to the database # Try to connect to the database and create it if it doesn't exist
with GraphDatabase.driver( for database in (DATABASE, None):
URI, self._DATABASE = database
auth=(USERNAME, PASSWORD), connected = False
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
) as _sync_driver:
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try: try:
with _sync_driver.session(database=database) as session: async with self._driver.session(database=database) as session:
try: try:
session.run("MATCH (n) RETURN n LIMIT 0") result = await session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {database} at {URI}") await result.consume() # Ensure result is consumed
connected = True logger.info(f"Connected to {database} at {URI}")
except neo4jExceptions.ServiceUnavailable as e: connected = True
logger.error( except neo4jExceptions.ServiceUnavailable as e:
f"{database} at {URI} is not available".capitalize() logger.error(
) f"{database} at {URI} is not available".capitalize()
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
) )
try: raise e
with _sync_driver.session() as session: except neo4jExceptions.AuthError as e:
session.run( logger.error(f"Authentication failed for {database} at {URI}")
f"CREATE DATABASE `{database}` IF NOT EXISTS" raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
)
try:
async with self._driver.session() as session:
result = await session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
) )
logger.info(f"{database} at {URI} created".capitalize()) if database is None:
connected = True logger.error(f"Failed to create {database} at {URI}")
except ( raise e
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if connected: if connected:
break break
def __post_init__(self): async def finalize(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def close(self):
"""Close the Neo4j driver and release all resources""" """Close the Neo4j driver and release all resources"""
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
@@ -170,7 +161,7 @@ class Neo4JStorage(BaseGraphStorage):
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits""" """Ensure driver is closed when context manager exits"""
await self.close() await self.finalize()
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# Noe4J handles persistence automatically # Noe4J handles persistence automatically