添加选取Neo4j指定数据库功能的支持
This commit is contained in:
@@ -1,18 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union, Tuple, List, Dict
|
from typing import Any, Union, Tuple, List, Dict
|
||||||
import inspect
|
|
||||||
from lightrag.utils import logger
|
|
||||||
from ..base import BaseGraphStorage
|
|
||||||
from neo4j import (
|
from neo4j import (
|
||||||
AsyncGraphDatabase,
|
AsyncGraphDatabase,
|
||||||
exceptions as neo4jExceptions,
|
exceptions as neo4jExceptions,
|
||||||
AsyncDriver,
|
AsyncDriver,
|
||||||
AsyncManagedTransaction,
|
AsyncManagedTransaction,
|
||||||
|
GraphDatabase
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
@@ -20,6 +18,9 @@ from tenacity import (
|
|||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from lightrag.utils import logger
|
||||||
|
from ..base import BaseGraphStorage
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Neo4JStorage(BaseGraphStorage):
|
class Neo4JStorage(BaseGraphStorage):
|
||||||
@@ -38,10 +39,38 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
URI = os.environ["NEO4J_URI"]
|
URI = os.environ["NEO4J_URI"]
|
||||||
USERNAME = os.environ["NEO4J_USERNAME"]
|
USERNAME = os.environ["NEO4J_USERNAME"]
|
||||||
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
||||||
|
DATABASE = os.environ.get(
|
||||||
|
'NEO4J_DATABASE') # If this param is None, the home database will be used. If it is not None, the specified database will be used.
|
||||||
|
self._DATABASE = DATABASE
|
||||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||||
URI, auth=(USERNAME, PASSWORD)
|
URI, auth=(USERNAME, PASSWORD)
|
||||||
)
|
)
|
||||||
return None
|
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
|
||||||
|
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
|
||||||
|
try:
|
||||||
|
with _sync_driver.session(database=DATABASE) as session:
|
||||||
|
try:
|
||||||
|
session.run('MATCH (n) RETURN n LIMIT 0')
|
||||||
|
logger.info(f"Connected to {DATABASE} at {URI}")
|
||||||
|
except neo4jExceptions.ServiceUnavailable as e:
|
||||||
|
logger.error(f"{DATABASE} at {URI} is not available".capitalize())
|
||||||
|
raise e
|
||||||
|
except neo4jExceptions.AuthError as e:
|
||||||
|
logger.error(f"Authentication failed for {DATABASE} at {URI}")
|
||||||
|
raise e
|
||||||
|
except neo4jExceptions.ClientError as e:
|
||||||
|
if e.code == 'Neo.ClientError.Database.DatabaseNotFound':
|
||||||
|
logger.info(f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize())
|
||||||
|
try:
|
||||||
|
with _sync_driver.session() as session:
|
||||||
|
session.run(f'CREATE DATABASE `{DATABASE}` IF NOT EXISTS')
|
||||||
|
logger.info(f"{DATABASE} at {URI} created".capitalize())
|
||||||
|
except neo4jExceptions.ClientError as e:
|
||||||
|
if e.code == "Neo.ClientError.Statement.UnsupportedAdministrationCommand":
|
||||||
|
logger.warning(
|
||||||
|
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead.")
|
||||||
|
logger.error(f"Failed to create {DATABASE} at {URI}")
|
||||||
|
raise e
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._node_embed_algorithms = {
|
self._node_embed_algorithms = {
|
||||||
@@ -63,7 +92,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = node_id.strip('"')
|
||||||
|
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||||
)
|
)
|
||||||
@@ -78,7 +107,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
entity_name_label_source = source_node_id.strip('"')
|
entity_name_label_source = source_node_id.strip('"')
|
||||||
entity_name_label_target = target_node_id.strip('"')
|
entity_name_label_target = target_node_id.strip('"')
|
||||||
|
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
||||||
"RETURN COUNT(r) > 0 AS edgeExists"
|
"RETURN COUNT(r) > 0 AS edgeExists"
|
||||||
@@ -91,7 +120,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
return single_result["edgeExists"]
|
return single_result["edgeExists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = node_id.strip('"')
|
||||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||||
result = await session.run(query)
|
result = await session.run(query)
|
||||||
@@ -108,7 +137,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = node_id.strip('"')
|
||||||
|
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:`{entity_name_label}`)
|
MATCH (n:`{entity_name_label}`)
|
||||||
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
||||||
@@ -141,7 +170,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
return degrees
|
return degrees
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> Union[dict, None]:
|
||||||
entity_name_label_source = source_node_id.strip('"')
|
entity_name_label_source = source_node_id.strip('"')
|
||||||
entity_name_label_target = target_node_id.strip('"')
|
entity_name_label_target = target_node_id.strip('"')
|
||||||
@@ -155,7 +184,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
Returns:
|
Returns:
|
||||||
list: List of all relationships/edges found
|
list: List of all relationships/edges found
|
||||||
"""
|
"""
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||||
RETURN properties(r) as edge_properties
|
RETURN properties(r) as edge_properties
|
||||||
@@ -186,7 +215,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
query = f"""MATCH (n:`{node_label}`)
|
query = f"""MATCH (n:`{node_label}`)
|
||||||
OPTIONAL MATCH (n)-[r]-(connected)
|
OPTIONAL MATCH (n)-[r]-(connected)
|
||||||
RETURN n, r, connected"""
|
RETURN n, r, connected"""
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
results = await session.run(query)
|
results = await session.run(query)
|
||||||
edges = []
|
edges = []
|
||||||
async for record in results:
|
async for record in results:
|
||||||
@@ -212,10 +241,10 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type(
|
retry=retry_if_exception_type(
|
||||||
(
|
(
|
||||||
neo4jExceptions.ServiceUnavailable,
|
neo4jExceptions.ServiceUnavailable,
|
||||||
neo4jExceptions.TransientError,
|
neo4jExceptions.TransientError,
|
||||||
neo4jExceptions.WriteServiceUnavailable,
|
neo4jExceptions.WriteServiceUnavailable,
|
||||||
neo4jExceptions.ClientError,
|
neo4jExceptions.ClientError,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -241,7 +270,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
await session.execute_write(_do_upsert)
|
await session.execute_write(_do_upsert)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during upsert: {str(e)}")
|
logger.error(f"Error during upsert: {str(e)}")
|
||||||
@@ -252,14 +281,14 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type(
|
retry=retry_if_exception_type(
|
||||||
(
|
(
|
||||||
neo4jExceptions.ServiceUnavailable,
|
neo4jExceptions.ServiceUnavailable,
|
||||||
neo4jExceptions.TransientError,
|
neo4jExceptions.TransientError,
|
||||||
neo4jExceptions.WriteServiceUnavailable,
|
neo4jExceptions.WriteServiceUnavailable,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Upsert an edge and its properties between two nodes identified by their labels.
|
Upsert an edge and its properties between two nodes identified by their labels.
|
||||||
@@ -288,7 +317,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
await session.execute_write(_do_upsert_edge)
|
await session.execute_write(_do_upsert_edge)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during edge upsert: {str(e)}")
|
logger.error(f"Error during edge upsert: {str(e)}")
|
||||||
|
Reference in New Issue
Block a user