添加选取Neo4j指定数据库功能的支持

This commit is contained in:
xiyihan
2025-01-04 21:47:52 +08:00
committed by GitHub
parent fb414582c7
commit bb4c271623

View File

@@ -1,18 +1,16 @@
import asyncio import asyncio
import inspect
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict from typing import Any, Union, Tuple, List, Dict
import inspect
from lightrag.utils import logger
from ..base import BaseGraphStorage
from neo4j import ( from neo4j import (
AsyncGraphDatabase, AsyncGraphDatabase,
exceptions as neo4jExceptions, exceptions as neo4jExceptions,
AsyncDriver, AsyncDriver,
AsyncManagedTransaction, AsyncManagedTransaction,
GraphDatabase
) )
from tenacity import ( from tenacity import (
retry, retry,
stop_after_attempt, stop_after_attempt,
@@ -20,6 +18,9 @@ from tenacity import (
retry_if_exception_type, retry_if_exception_type,
) )
from lightrag.utils import logger
from ..base import BaseGraphStorage
@dataclass @dataclass
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@@ -38,10 +39,38 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"] URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"] USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"] PASSWORD = os.environ["NEO4J_PASSWORD"]
DATABASE = os.environ.get(
'NEO4J_DATABASE') # If the NEO4J_DATABASE environment variable is unset (None), the home database is used; otherwise the specified database is used.
self._DATABASE = DATABASE
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD) URI, auth=(USERNAME, PASSWORD)
) )
return None _database_name = "home database" if DATABASE is None else f"database {DATABASE}"
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
try:
with _sync_driver.session(database=DATABASE) as session:
try:
session.run('MATCH (n) RETURN n LIMIT 0')
logger.info(f"Connected to {DATABASE} at {URI}")
except neo4jExceptions.ServiceUnavailable as e:
logger.error(f"{DATABASE} at {URI} is not available".capitalize())
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {DATABASE} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == 'Neo.ClientError.Database.DatabaseNotFound':
logger.info(f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize())
try:
with _sync_driver.session() as session:
session.run(f'CREATE DATABASE `{DATABASE}` IF NOT EXISTS')
logger.info(f"{DATABASE} at {URI} created".capitalize())
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Statement.UnsupportedAdministrationCommand":
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead.")
logger.error(f"Failed to create {DATABASE} at {URI}")
raise e
def __post_init__(self): def __post_init__(self):
self._node_embed_algorithms = { self._node_embed_algorithms = {
@@ -63,7 +92,7 @@ class Neo4JStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"') entity_name_label = node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
query = ( query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
) )
@@ -78,7 +107,7 @@ class Neo4JStorage(BaseGraphStorage):
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
query = ( query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists" "RETURN COUNT(r) > 0 AS edgeExists"
@@ -91,7 +120,7 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"] return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
entity_name_label = node_id.strip('"') entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n" query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query) result = await session.run(query)
@@ -108,7 +137,7 @@ class Neo4JStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"') entity_name_label = node_id.strip('"')
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
query = f""" query = f"""
MATCH (n:`{entity_name_label}`) MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount RETURN COUNT{{ (n)--() }} AS totalEdgeCount
@@ -155,7 +184,7 @@ class Neo4JStorage(BaseGraphStorage):
Returns: Returns:
list: List of all relationships/edges found list: List of all relationships/edges found
""" """
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
query = f""" query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties RETURN properties(r) as edge_properties
@@ -186,7 +215,7 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`) query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected""" RETURN n, r, connected"""
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
results = await session.run(query) results = await session.run(query)
edges = [] edges = []
async for record in results: async for record in results:
@@ -241,7 +270,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
try: try:
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_upsert) await session.execute_write(_do_upsert)
except Exception as e: except Exception as e:
logger.error(f"Error during upsert: {str(e)}") logger.error(f"Error during upsert: {str(e)}")
@@ -288,7 +317,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
try: try:
async with self._driver.session() as session: async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_upsert_edge) await session.execute_write(_do_upsert_edge)
except Exception as e: except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}") logger.error(f"Error during edge upsert: {str(e)}")