diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 5db3016f..884fcb40 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,18 +1,16 @@ import asyncio +import inspect import os from dataclasses import dataclass from typing import Any, Union, Tuple, List, Dict -import inspect -from lightrag.utils import logger -from ..base import BaseGraphStorage + from neo4j import ( AsyncGraphDatabase, exceptions as neo4jExceptions, AsyncDriver, AsyncManagedTransaction, + GraphDatabase, ) - - from tenacity import ( retry, stop_after_attempt, @@ -20,6 +18,9 @@ from tenacity import ( retry_if_exception_type, ) +from lightrag.utils import logger +from ..base import BaseGraphStorage + @dataclass class Neo4JStorage(BaseGraphStorage): @@ -38,10 +39,47 @@ class Neo4JStorage(BaseGraphStorage): URI = os.environ["NEO4J_URI"] USERNAME = os.environ["NEO4J_USERNAME"] PASSWORD = os.environ["NEO4J_PASSWORD"] + DATABASE = os.environ.get( + "NEO4J_DATABASE" + ) # If this param is None, the home database will be used. If it is not None, the specified database will be used. + self._DATABASE = DATABASE self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, auth=(USERNAME, PASSWORD) ) - return None + _database_name = "home database" if DATABASE is None else f"database {DATABASE}" + with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver: + try: + with _sync_driver.session(database=DATABASE) as session: + try: + session.run("MATCH (n) RETURN n LIMIT 0") + logger.info(f"Connected to {DATABASE} at {URI}") + except neo4jExceptions.ServiceUnavailable as e: + logger.error( + f"{DATABASE} at {URI} is not available".capitalize() + ) + raise e + except neo4jExceptions.AuthError as e: + logger.error(f"Authentication failed for {DATABASE} at {URI}") + raise e + except neo4jExceptions.ClientError as e: + if e.code == "Neo.ClientError.Database.DatabaseNotFound": + logger.info( + f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize() + ) + try: + with _sync_driver.session() as session: + session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS") + logger.info(f"{DATABASE} at {URI} created".capitalize()) + except neo4jExceptions.ClientError as e: + if ( + e.code + == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" + ): + logger.warning( + "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead." + ) + logger.error(f"Failed to create {DATABASE} at {URI}") + raise e def __post_init__(self): self._node_embed_algorithms = { @@ -63,7 +101,7 @@ class Neo4JStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: entity_name_label = node_id.strip('"') - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" ) @@ -78,7 +116,7 @@ class Neo4JStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = ( f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " "RETURN COUNT(r) > 0 AS edgeExists" @@ -91,7 +129,7 @@ class Neo4JStorage(BaseGraphStorage): return single_result["edgeExists"] async def get_node(self, node_id: str) -> Union[dict, None]: - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: entity_name_label = node_id.strip('"') query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) @@ -108,7 +146,7 @@ class Neo4JStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: entity_name_label = node_id.strip('"') - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = f""" MATCH (n:`{entity_name_label}`) RETURN COUNT{{ (n)--() }} AS totalEdgeCount @@ -155,7 +193,7 @@ class Neo4JStorage(BaseGraphStorage): Returns: list: List of all relationships/edges found """ - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = f""" MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) RETURN properties(r) as edge_properties @@ -186,7 +224,7 @@ class Neo4JStorage(BaseGraphStorage): query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: results = await session.run(query) edges = [] async for record in results: @@ -241,7 +279,7 @@ class Neo4JStorage(BaseGraphStorage): ) try: - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_upsert) except Exception as e: logger.error(f"Error during upsert: {str(e)}") @@ -288,7 +326,7 @@ class Neo4JStorage(BaseGraphStorage): ) try: - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_upsert_edge) except Exception as e: logger.error(f"Error during edge upsert: {str(e)}")