@@ -1,18 +1,16 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, Tuple, List, Dict
|
||||
import inspect
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase,
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
GraphDatabase,
|
||||
)
|
||||
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
@@ -20,6 +18,9 @@ from tenacity import (
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Neo4JStorage(BaseGraphStorage):
|
||||
@@ -38,10 +39,47 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
URI = os.environ["NEO4J_URI"]
|
||||
USERNAME = os.environ["NEO4J_USERNAME"]
|
||||
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
||||
DATABASE = os.environ.get(
|
||||
"NEO4J_DATABASE"
|
||||
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
|
||||
self._DATABASE = DATABASE
|
||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||
URI, auth=(USERNAME, PASSWORD)
|
||||
)
|
||||
return None
|
||||
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
|
||||
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
|
||||
try:
|
||||
with _sync_driver.session(database=DATABASE) as session:
|
||||
try:
|
||||
session.run("MATCH (n) RETURN n LIMIT 0")
|
||||
logger.info(f"Connected to {DATABASE} at {URI}")
|
||||
except neo4jExceptions.ServiceUnavailable as e:
|
||||
logger.error(
|
||||
f"{DATABASE} at {URI} is not available".capitalize()
|
||||
)
|
||||
raise e
|
||||
except neo4jExceptions.AuthError as e:
|
||||
logger.error(f"Authentication failed for {DATABASE} at {URI}")
|
||||
raise e
|
||||
except neo4jExceptions.ClientError as e:
|
||||
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
||||
logger.info(
|
||||
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
|
||||
)
|
||||
try:
|
||||
with _sync_driver.session() as session:
|
||||
session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
|
||||
logger.info(f"{DATABASE} at {URI} created".capitalize())
|
||||
except neo4jExceptions.ClientError as e:
|
||||
if (
|
||||
e.code
|
||||
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
||||
):
|
||||
logger.warning(
|
||||
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
|
||||
)
|
||||
logger.error(f"Failed to create {DATABASE} at {URI}")
|
||||
raise e
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
@@ -63,7 +101,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = (
|
||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||
)
|
||||
@@ -78,7 +116,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = (
|
||||
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
||||
"RETURN COUNT(r) > 0 AS edgeExists"
|
||||
@@ -91,7 +129,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return single_result["edgeExists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
entity_name_label = node_id.strip('"')
|
||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||
result = await session.run(query)
|
||||
@@ -108,7 +146,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (n:`{entity_name_label}`)
|
||||
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
||||
@@ -155,7 +193,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
@@ -186,7 +224,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
query = f"""MATCH (n:`{node_label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected"""
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
results = await session.run(query)
|
||||
edges = []
|
||||
async for record in results:
|
||||
@@ -241,7 +279,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during upsert: {str(e)}")
|
||||
@@ -288,7 +326,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert_edge)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge upsert: {str(e)}")
|
||||
|
Reference in New Issue
Block a user