Refactor Neo4J storage initialization and cleanup

- Make initialization async
- Rename close() to finalize()
This commit is contained in:
yangdx
2025-04-02 10:45:21 +08:00
parent ab9d210fcd
commit 5f678adb71

View File

@@ -1,4 +1,3 @@
import asyncio
import inspect import inspect
import os import os
import re import re
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
exceptions as neo4jExceptions, exceptions as neo4jExceptions,
AsyncDriver, AsyncDriver,
AsyncManagedTransaction, AsyncManagedTransaction,
GraphDatabase,
) )
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self._driver = None self._driver = None
self._driver_lock = asyncio.Lock()
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
USERNAME = os.environ.get( USERNAME = os.environ.get(
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
), ),
) )
DATABASE = os.environ.get( DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
) )
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,22 +101,16 @@ class Neo4JStorage(BaseGraphStorage):
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
) )
# Try to connect to the database # Try to connect to the database and create it if it doesn't exist
with GraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
) as _sync_driver:
for database in (DATABASE, None): for database in (DATABASE, None):
self._DATABASE = database self._DATABASE = database
connected = False connected = False
try: try:
with _sync_driver.session(database=database) as session: async with self._driver.session(database=database) as session:
try: try:
session.run("MATCH (n) RETURN n LIMIT 0") result = await session.run("MATCH (n) RETURN n LIMIT 0")
await result.consume() # Ensure result is consumed
logger.info(f"Connected to {database} at {URI}") logger.info(f"Connected to {database} at {URI}")
connected = True connected = True
except neo4jExceptions.ServiceUnavailable as e: except neo4jExceptions.ServiceUnavailable as e:
@@ -130,10 +127,11 @@ class Neo4JStorage(BaseGraphStorage):
f"{database} at {URI} not found. Try to create specified database.".capitalize() f"{database} at {URI} not found. Try to create specified database.".capitalize()
) )
try: try:
with _sync_driver.session() as session: async with self._driver.session() as session:
session.run( result = await session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS" f"CREATE DATABASE `{database}` IF NOT EXISTS"
) )
await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize()) logger.info(f"{database} at {URI} created".capitalize())
connected = True connected = True
except ( except (
@@ -143,9 +141,7 @@ class Neo4JStorage(BaseGraphStorage):
if ( if (
e.code e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand" == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or ( ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None: if database is not None:
logger.warning( logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
@@ -157,12 +153,7 @@ class Neo4JStorage(BaseGraphStorage):
if connected: if connected:
break break
def __post_init__(self): async def finalize(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def close(self):
"""Close the Neo4j driver and release all resources""" """Close the Neo4j driver and release all resources"""
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
@@ -170,7 +161,7 @@ class Neo4JStorage(BaseGraphStorage):
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits""" """Ensure driver is closed when context manager exits"""
await self.close() await self.finalize()
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# Noe4J handles persistence automatically # Noe4J handles persistence automatically