Refactor Neo4J storage initialization and cleanup
- Make initialization async - Rename close() to finalize()
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
GraphDatabase,
|
||||
)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self._driver = None
|
||||
self._driver_lock = asyncio.Lock()
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
||||
USERNAME = os.environ.get(
|
||||
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
||||
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
),
|
||||
)
|
||||
DATABASE = os.environ.get(
|
||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
|
||||
)
|
||||
|
||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||
@@ -98,22 +101,16 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
||||
)
|
||||
|
||||
# Try to connect to the database
|
||||
with GraphDatabase.driver(
|
||||
URI,
|
||||
auth=(USERNAME, PASSWORD),
|
||||
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
||||
connection_timeout=CONNECTION_TIMEOUT,
|
||||
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
||||
) as _sync_driver:
|
||||
# Try to connect to the database and create it if it doesn't exist
|
||||
for database in (DATABASE, None):
|
||||
self._DATABASE = database
|
||||
connected = False
|
||||
|
||||
try:
|
||||
with _sync_driver.session(database=database) as session:
|
||||
async with self._driver.session(database=database) as session:
|
||||
try:
|
||||
session.run("MATCH (n) RETURN n LIMIT 0")
|
||||
result = await session.run("MATCH (n) RETURN n LIMIT 0")
|
||||
await result.consume() # Ensure result is consumed
|
||||
logger.info(f"Connected to {database} at {URI}")
|
||||
connected = True
|
||||
except neo4jExceptions.ServiceUnavailable as e:
|
||||
@@ -130,10 +127,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
||||
)
|
||||
try:
|
||||
with _sync_driver.session() as session:
|
||||
session.run(
|
||||
async with self._driver.session() as session:
|
||||
result = await session.run(
|
||||
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
||||
)
|
||||
await result.consume() # Ensure result is consumed
|
||||
logger.info(f"{database} at {URI} created".capitalize())
|
||||
connected = True
|
||||
except (
|
||||
@@ -143,9 +141,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
if (
|
||||
e.code
|
||||
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
||||
) or (
|
||||
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
||||
):
|
||||
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
|
||||
if database is not None:
|
||||
logger.warning(
|
||||
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
||||
@@ -157,12 +153,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
if connected:
|
||||
break
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
async def finalize(self):
|
||||
"""Close the Neo4j driver and release all resources"""
|
||||
if self._driver:
|
||||
await self._driver.close()
|
||||
@@ -170,7 +161,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
"""Ensure driver is closed when context manager exits"""
|
||||
await self.close()
|
||||
await self.finalize()
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Noe4J handles persistence automatically
|
||||
|
Reference in New Issue
Block a user