Refactor Neo4J storage initialization and cleanup

- Make initialization async
- Rename close() to finalize()
This commit is contained in:
yangdx
2025-04-02 10:45:21 +08:00
parent ab9d210fcd
commit 5f678adb71

View File

@@ -1,4 +1,3 @@
import asyncio
import inspect
import os
import re
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
embedding_func=embedding_func,
)
self._driver = None
self._driver_lock = asyncio.Lock()
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
USERNAME = os.environ.get(
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
),
)
DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
)
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,59 @@ class Neo4JStorage(BaseGraphStorage):
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
)
# Try to connect to the database
with GraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
) as _sync_driver:
for database in (DATABASE, None):
self._DATABASE = database
connected = False
# Try to connect to the database and create it if it doesn't exist
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try:
with _sync_driver.session(database=database) as session:
try:
session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {database} at {URI}")
connected = True
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{database} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
try:
async with self._driver.session(database=database) as session:
try:
result = await session.run("MATCH (n) RETURN n LIMIT 0")
await result.consume() # Ensure result is consumed
logger.info(f"Connected to {database} at {URI}")
connected = True
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{database} at {URI} is not available".capitalize()
)
try:
with _sync_driver.session() as session:
session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
)
try:
async with self._driver.session() as session:
result = await session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if connected:
break
if connected:
break
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def close(self):
async def finalize(self):
"""Close the Neo4j driver and release all resources"""
if self._driver:
await self._driver.close()
@@ -170,7 +161,7 @@ class Neo4JStorage(BaseGraphStorage):
async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits"""
await self.close()
await self.finalize()
async def index_done_callback(self) -> None:
# Noe4J handles persistence automatically