Merge branch 'HKUDS:main' into main

This commit is contained in:
Samuel Chan
2025-01-06 12:53:06 +08:00
committed by GitHub

View File

@@ -1,18 +1,16 @@
import asyncio
import inspect
import os
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import inspect
from lightrag.utils import logger
from ..base import BaseGraphStorage
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
from tenacity import (
retry,
stop_after_attempt,
@@ -20,6 +18,9 @@ from tenacity import (
retry_if_exception_type,
)
from lightrag.utils import logger
from ..base import BaseGraphStorage
@dataclass
class Neo4JStorage(BaseGraphStorage):
@@ -38,10 +39,47 @@ class Neo4JStorage(BaseGraphStorage):
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
DATABASE = os.environ.get(
"NEO4J_DATABASE"
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
self._DATABASE = DATABASE
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
try:
with _sync_driver.session(database=DATABASE) as session:
try:
session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {DATABASE} at {URI}")
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{DATABASE} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {DATABASE} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
)
try:
with _sync_driver.session() as session:
session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
logger.info(f"{DATABASE} at {URI} created".capitalize())
except neo4jExceptions.ClientError as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
):
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
)
logger.error(f"Failed to create {DATABASE} at {URI}")
raise e
def __post_init__(self):
self._node_embed_algorithms = {
@@ -63,7 +101,7 @@ class Neo4JStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
@@ -78,7 +116,7 @@ class Neo4JStorage(BaseGraphStorage):
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
@@ -91,7 +129,7 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
@@ -108,7 +146,7 @@ class Neo4JStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = f"""
MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
@@ -155,7 +193,7 @@ class Neo4JStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
@@ -186,7 +224,7 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
results = await session.run(query)
edges = []
async for record in results:
@@ -241,7 +279,7 @@ class Neo4JStorage(BaseGraphStorage):
)
try:
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
@@ -288,7 +326,7 @@ class Neo4JStorage(BaseGraphStorage):
)
try:
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_upsert_edge)
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")