From 1e3b25db22151a2240c4fcdfaf50fd06727d7340 Mon Sep 17 00:00:00 2001 From: xiyihan <71264788+xiyihan0@users.noreply.github.com> Date: Sat, 4 Jan 2025 22:33:35 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=80=89=E5=8F=96Neo4j?= =?UTF-8?q?=E6=8C=87=E5=AE=9A=E6=95=B0=E6=8D=AE=E5=BA=93=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E7=9A=84=E6=94=AF=E6=8C=81(fix=20lint)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/kg/neo4j_impl.py | 45 +++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 70f387b7..884fcb40 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -9,7 +9,7 @@ from neo4j import ( exceptions as neo4jExceptions, AsyncDriver, AsyncManagedTransaction, - GraphDatabase + GraphDatabase, ) from tenacity import ( retry, @@ -40,7 +40,8 @@ class Neo4JStorage(BaseGraphStorage): USERNAME = os.environ["NEO4J_USERNAME"] PASSWORD = os.environ["NEO4J_PASSWORD"] DATABASE = os.environ.get( - 'NEO4J_DATABASE') # If this param is None, the home database will be used. If it is not None, the specified database will be used. + "NEO4J_DATABASE" + ) # If this param is None, the home database will be used. If it is not None, the specified database will be used. self._DATABASE = DATABASE self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, auth=(USERNAME, PASSWORD) @@ -50,25 +51,33 @@ class Neo4JStorage(BaseGraphStorage): try: with _sync_driver.session(database=DATABASE) as session: try: - session.run('MATCH (n) RETURN n LIMIT 0') + session.run("MATCH (n) RETURN n LIMIT 0") logger.info(f"Connected to {DATABASE} at {URI}") except neo4jExceptions.ServiceUnavailable as e: - logger.error(f"{DATABASE} at {URI} is not available".capitalize()) + logger.error( + f"{DATABASE} at {URI} is not available".capitalize() + ) raise e except neo4jExceptions.AuthError as e: logger.error(f"Authentication failed for {DATABASE} at {URI}") raise e except neo4jExceptions.ClientError as e: - if e.code == 'Neo.ClientError.Database.DatabaseNotFound': - logger.info(f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()) + if e.code == "Neo.ClientError.Database.DatabaseNotFound": + logger.info( + f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize() + ) try: with _sync_driver.session() as session: - session.run(f'CREATE DATABASE `{DATABASE}` IF NOT EXISTS') + session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS") logger.info(f"{DATABASE} at {URI} created".capitalize()) except neo4jExceptions.ClientError as e: - if e.code == "Neo.ClientError.Statement.UnsupportedAdministrationCommand": + if ( + e.code + == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" + ): logger.warning( - "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead.") + "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead." + ) logger.error(f"Failed to create {DATABASE} at {URI}") raise e @@ -170,7 +179,7 @@ class Neo4JStorage(BaseGraphStorage): return degrees async def get_edge( - self, source_node_id: str, target_node_id: str + self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -241,10 +250,10 @@ class Neo4JStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type( ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.WriteServiceUnavailable, - neo4jExceptions.ClientError, + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, ) ), ) @@ -281,14 +290,14 @@ class Neo4JStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type( ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, ) ), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] + self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] ): """ Upsert an edge and its properties between two nodes identified by their labels.