From bb4c271623c84498b95a6b47ef295382f35ad34e Mon Sep 17 00:00:00 2001 From: xiyihan <71264788+xiyihan0@users.noreply.github.com> Date: Sat, 4 Jan 2025 21:47:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=80=89=E5=8F=96Neo4j?= =?UTF-8?q?=E6=8C=87=E5=AE=9A=E6=95=B0=E6=8D=AE=E5=BA=93=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/kg/neo4j_impl.py | 75 +++++++++++++++++++++++++++------------ 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 5db3016f..70f387b7 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,18 +1,16 @@ import asyncio +import inspect import os from dataclasses import dataclass from typing import Any, Union, Tuple, List, Dict -import inspect -from lightrag.utils import logger -from ..base import BaseGraphStorage + from neo4j import ( AsyncGraphDatabase, exceptions as neo4jExceptions, AsyncDriver, AsyncManagedTransaction, + GraphDatabase ) - - from tenacity import ( retry, stop_after_attempt, @@ -20,6 +18,9 @@ from tenacity import ( retry_if_exception_type, ) +from lightrag.utils import logger +from ..base import BaseGraphStorage + @dataclass class Neo4JStorage(BaseGraphStorage): @@ -38,10 +39,38 @@ class Neo4JStorage(BaseGraphStorage): URI = os.environ["NEO4J_URI"] USERNAME = os.environ["NEO4J_USERNAME"] PASSWORD = os.environ["NEO4J_PASSWORD"] + DATABASE = os.environ.get( + 'NEO4J_DATABASE') # If this param is None, the home database will be used. If it is not None, the specified database will be used. + self._DATABASE = DATABASE self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, auth=(USERNAME, PASSWORD) ) - return None + _database_name = "home database" if DATABASE is None else f"database {DATABASE}" + with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver: + try: + with _sync_driver.session(database=DATABASE) as session: + try: + session.run('MATCH (n) RETURN n LIMIT 0') + logger.info(f"Connected to {DATABASE} at {URI}") + except neo4jExceptions.ServiceUnavailable as e: + logger.error(f"{DATABASE} at {URI} is not available".capitalize()) + raise e + except neo4jExceptions.AuthError as e: + logger.error(f"Authentication failed for {DATABASE} at {URI}") + raise e + except neo4jExceptions.ClientError as e: + if e.code == 'Neo.ClientError.Database.DatabaseNotFound': + logger.info(f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()) + try: + with _sync_driver.session() as session: + session.run(f'CREATE DATABASE `{DATABASE}` IF NOT EXISTS') + logger.info(f"{DATABASE} at {URI} created".capitalize()) + except neo4jExceptions.ClientError as e: + if e.code == "Neo.ClientError.Statement.UnsupportedAdministrationCommand": + logger.warning( + "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead.") + logger.error(f"Failed to create {DATABASE} at {URI}") + raise e def __post_init__(self): self._node_embed_algorithms = { @@ -63,7 +92,7 @@ class Neo4JStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: entity_name_label = node_id.strip('"') - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" ) @@ -78,7 +107,7 @@ class Neo4JStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = ( f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " "RETURN COUNT(r) > 0 AS edgeExists" @@ -91,7 +120,7 @@ class Neo4JStorage(BaseGraphStorage): return single_result["edgeExists"] async def get_node(self, node_id: str) -> Union[dict, None]: - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: entity_name_label = node_id.strip('"') query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) @@ -108,7 +137,7 @@ class Neo4JStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: entity_name_label = node_id.strip('"') - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = f""" MATCH (n:`{entity_name_label}`) RETURN COUNT{{ (n)--() }} AS totalEdgeCount @@ -141,7 +170,7 @@ class Neo4JStorage(BaseGraphStorage): return degrees async def get_edge( - self, source_node_id: str, target_node_id: str + self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -155,7 +184,7 @@ class Neo4JStorage(BaseGraphStorage): Returns: list: List of all relationships/edges found """ - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: query = f""" MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) RETURN properties(r) as edge_properties @@ -186,7 +215,7 @@ class Neo4JStorage(BaseGraphStorage): query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: results = await session.run(query) edges = [] async for record in results: @@ -212,10 +241,10 @@ class Neo4JStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type( ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.WriteServiceUnavailable, - neo4jExceptions.ClientError, + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, ) ), ) @@ -241,7 +270,7 @@ class Neo4JStorage(BaseGraphStorage): ) try: - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_upsert) except Exception as e: logger.error(f"Error during upsert: {str(e)}") @@ -252,14 +281,14 @@ class Neo4JStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type( ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, ) ), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] + self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] ): """ Upsert an edge and its properties between two nodes identified by their labels. @@ -288,7 +317,7 @@ class Neo4JStorage(BaseGraphStorage): ) try: - async with self._driver.session() as session: + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_upsert_edge) except Exception as e: logger.error(f"Error during edge upsert: {str(e)}")