添加选取Neo4j指定数据库功能的支持
This commit is contained in:
@@ -1,18 +1,16 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, Tuple, List, Dict
|
||||
import inspect
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase,
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
GraphDatabase
|
||||
)
|
||||
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
@@ -20,6 +18,9 @@ from tenacity import (
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Neo4JStorage(BaseGraphStorage):
|
||||
@@ -38,10 +39,38 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
URI = os.environ["NEO4J_URI"]
|
||||
USERNAME = os.environ["NEO4J_USERNAME"]
|
||||
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
||||
DATABASE = os.environ.get(
|
||||
'NEO4J_DATABASE') # If this param is None, the home database will be used. If it is not None, the specified database will be used.
|
||||
self._DATABASE = DATABASE
|
||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||
URI, auth=(USERNAME, PASSWORD)
|
||||
)
|
||||
return None
|
||||
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
|
||||
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
|
||||
try:
|
||||
with _sync_driver.session(database=DATABASE) as session:
|
||||
try:
|
||||
session.run('MATCH (n) RETURN n LIMIT 0')
|
||||
logger.info(f"Connected to {DATABASE} at {URI}")
|
||||
except neo4jExceptions.ServiceUnavailable as e:
|
||||
logger.error(f"{DATABASE} at {URI} is not available".capitalize())
|
||||
raise e
|
||||
except neo4jExceptions.AuthError as e:
|
||||
logger.error(f"Authentication failed for {DATABASE} at {URI}")
|
||||
raise e
|
||||
except neo4jExceptions.ClientError as e:
|
||||
if e.code == 'Neo.ClientError.Database.DatabaseNotFound':
|
||||
logger.info(f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize())
|
||||
try:
|
||||
with _sync_driver.session() as session:
|
||||
session.run(f'CREATE DATABASE `{DATABASE}` IF NOT EXISTS')
|
||||
logger.info(f"{DATABASE} at {URI} created".capitalize())
|
||||
except neo4jExceptions.ClientError as e:
|
||||
if e.code == "Neo.ClientError.Statement.UnsupportedAdministrationCommand":
|
||||
logger.warning(
|
||||
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead.")
|
||||
logger.error(f"Failed to create {DATABASE} at {URI}")
|
||||
raise e
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
@@ -63,7 +92,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = (
|
||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||
)
|
||||
@@ -78,7 +107,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = (
|
||||
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
||||
"RETURN COUNT(r) > 0 AS edgeExists"
|
||||
@@ -91,7 +120,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return single_result["edgeExists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
entity_name_label = node_id.strip('"')
|
||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||
result = await session.run(query)
|
||||
@@ -108,7 +137,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (n:`{entity_name_label}`)
|
||||
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
||||
@@ -141,7 +170,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return degrees
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
@@ -155,7 +184,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
@@ -186,7 +215,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
query = f"""MATCH (n:`{node_label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected"""
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
results = await session.run(query)
|
||||
edges = []
|
||||
async for record in results:
|
||||
@@ -212,10 +241,10 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
neo4jExceptions.WriteServiceUnavailable,
|
||||
neo4jExceptions.ClientError,
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
neo4jExceptions.WriteServiceUnavailable,
|
||||
neo4jExceptions.ClientError,
|
||||
)
|
||||
),
|
||||
)
|
||||
@@ -241,7 +270,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during upsert: {str(e)}")
|
||||
@@ -252,14 +281,14 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
neo4jExceptions.WriteServiceUnavailable,
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
neo4jExceptions.WriteServiceUnavailable,
|
||||
)
|
||||
),
|
||||
)
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their labels.
|
||||
@@ -288,7 +317,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert_edge)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge upsert: {str(e)}")
|
||||
|
Reference in New Issue
Block a user