Refactor Neo4J storage initialization and cleanup
- Make initialization async - Rename close() to finalize()
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
|
|||||||
exceptions as neo4jExceptions,
|
exceptions as neo4jExceptions,
|
||||||
AsyncDriver,
|
AsyncDriver,
|
||||||
AsyncManagedTransaction,
|
AsyncManagedTransaction,
|
||||||
GraphDatabase,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
embedding_func=embedding_func,
|
embedding_func=embedding_func,
|
||||||
)
|
)
|
||||||
self._driver = None
|
self._driver = None
|
||||||
self._driver_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self._node_embed_algorithms = {
|
||||||
|
"node2vec": self._node2vec_embed,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
||||||
USERNAME = os.environ.get(
|
USERNAME = os.environ.get(
|
||||||
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
||||||
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
DATABASE = os.environ.get(
|
DATABASE = os.environ.get(
|
||||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||||
@@ -98,22 +101,16 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to connect to the database
|
# Try to connect to the database and create it if it doesn't exist
|
||||||
with GraphDatabase.driver(
|
|
||||||
URI,
|
|
||||||
auth=(USERNAME, PASSWORD),
|
|
||||||
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
|
||||||
connection_timeout=CONNECTION_TIMEOUT,
|
|
||||||
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
|
||||||
) as _sync_driver:
|
|
||||||
for database in (DATABASE, None):
|
for database in (DATABASE, None):
|
||||||
self._DATABASE = database
|
self._DATABASE = database
|
||||||
connected = False
|
connected = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with _sync_driver.session(database=database) as session:
|
async with self._driver.session(database=database) as session:
|
||||||
try:
|
try:
|
||||||
session.run("MATCH (n) RETURN n LIMIT 0")
|
result = await session.run("MATCH (n) RETURN n LIMIT 0")
|
||||||
|
await result.consume() # Ensure result is consumed
|
||||||
logger.info(f"Connected to {database} at {URI}")
|
logger.info(f"Connected to {database} at {URI}")
|
||||||
connected = True
|
connected = True
|
||||||
except neo4jExceptions.ServiceUnavailable as e:
|
except neo4jExceptions.ServiceUnavailable as e:
|
||||||
@@ -130,10 +127,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with _sync_driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
session.run(
|
result = await session.run(
|
||||||
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
||||||
)
|
)
|
||||||
|
await result.consume() # Ensure result is consumed
|
||||||
logger.info(f"{database} at {URI} created".capitalize())
|
logger.info(f"{database} at {URI} created".capitalize())
|
||||||
connected = True
|
connected = True
|
||||||
except (
|
except (
|
||||||
@@ -143,9 +141,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
if (
|
if (
|
||||||
e.code
|
e.code
|
||||||
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
||||||
) or (
|
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
|
||||||
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
|
||||||
):
|
|
||||||
if database is not None:
|
if database is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
||||||
@@ -157,12 +153,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
if connected:
|
if connected:
|
||||||
break
|
break
|
||||||
|
|
||||||
def __post_init__(self):
|
async def finalize(self):
|
||||||
self._node_embed_algorithms = {
|
|
||||||
"node2vec": self._node2vec_embed,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""Close the Neo4j driver and release all resources"""
|
"""Close the Neo4j driver and release all resources"""
|
||||||
if self._driver:
|
if self._driver:
|
||||||
await self._driver.close()
|
await self._driver.close()
|
||||||
@@ -170,7 +161,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
"""Ensure driver is closed when context manager exits"""
|
"""Ensure driver is closed when context manager exits"""
|
||||||
await self.close()
|
await self.finalize()
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
# Noe4J handles persistence automatically
|
# Noe4J handles persistence automatically
|
||||||
|
Reference in New Issue
Block a user