添加选取Neo4j指定数据库功能的支持(fix lint)

This commit is contained in:
xiyihan
2025-01-04 22:33:35 +08:00
committed by GitHub
parent bb4c271623
commit 1e3b25db22

View File

@@ -9,7 +9,7 @@ from neo4j import (
exceptions as neo4jExceptions, exceptions as neo4jExceptions,
AsyncDriver, AsyncDriver,
AsyncManagedTransaction, AsyncManagedTransaction,
GraphDatabase GraphDatabase,
) )
from tenacity import ( from tenacity import (
retry, retry,
@@ -40,7 +40,8 @@ class Neo4JStorage(BaseGraphStorage):
USERNAME = os.environ["NEO4J_USERNAME"] USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"] PASSWORD = os.environ["NEO4J_PASSWORD"]
DATABASE = os.environ.get( DATABASE = os.environ.get(
'NEO4J_DATABASE') # If this param is None, the home database will be used. If it is not None, the specified database will be used. "NEO4J_DATABASE"
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
self._DATABASE = DATABASE self._DATABASE = DATABASE
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD) URI, auth=(USERNAME, PASSWORD)
@@ -50,25 +51,33 @@ class Neo4JStorage(BaseGraphStorage):
try: try:
with _sync_driver.session(database=DATABASE) as session: with _sync_driver.session(database=DATABASE) as session:
try: try:
session.run('MATCH (n) RETURN n LIMIT 0') session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {DATABASE} at {URI}") logger.info(f"Connected to {DATABASE} at {URI}")
except neo4jExceptions.ServiceUnavailable as e: except neo4jExceptions.ServiceUnavailable as e:
logger.error(f"{DATABASE} at {URI} is not available".capitalize()) logger.error(
f"{DATABASE} at {URI} is not available".capitalize()
)
raise e raise e
except neo4jExceptions.AuthError as e: except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {DATABASE} at {URI}") logger.error(f"Authentication failed for {DATABASE} at {URI}")
raise e raise e
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
if e.code == 'Neo.ClientError.Database.DatabaseNotFound': if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()) logger.info(
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
)
try: try:
with _sync_driver.session() as session: with _sync_driver.session() as session:
session.run(f'CREATE DATABASE `{DATABASE}` IF NOT EXISTS') session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
logger.info(f"{DATABASE} at {URI} created".capitalize()) logger.info(f"{DATABASE} at {URI} created".capitalize())
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Statement.UnsupportedAdministrationCommand": if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
):
logger.warning( logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead.") "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
)
logger.error(f"Failed to create {DATABASE} at {URI}") logger.error(f"Failed to create {DATABASE} at {URI}")
raise e raise e
@@ -170,7 +179,7 @@ class Neo4JStorage(BaseGraphStorage):
return degrees return degrees
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
@@ -241,10 +250,10 @@ class Neo4JStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type( retry=retry_if_exception_type(
( (
neo4jExceptions.ServiceUnavailable, neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError, neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable, neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError, neo4jExceptions.ClientError,
) )
), ),
) )
@@ -281,14 +290,14 @@ class Neo4JStorage(BaseGraphStorage):
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type( retry=retry_if_exception_type(
( (
neo4jExceptions.ServiceUnavailable, neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError, neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable, neo4jExceptions.WriteServiceUnavailable,
) )
), ),
) )
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
): ):
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels.