diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f4e18446..fe01aaf3 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,6 +1,7 @@ import asyncio import inspect import os +import re from dataclasses import dataclass from typing import Any, Union, Tuple, List, Dict import pipmaster as pm @@ -22,7 +23,7 @@ from tenacity import ( retry_if_exception_type, ) -from lightrag.utils import logger +from ..utils import logger from ..base import BaseGraphStorage @@ -45,50 +46,68 @@ class Neo4JStorage(BaseGraphStorage): PASSWORD = os.environ["NEO4J_PASSWORD"] MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) DATABASE = os.environ.get( - "NEO4J_DATABASE" - ) # If this param is None, the home database will be used. If it is not None, the specified database will be used. - self._DATABASE = DATABASE + "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) + ) self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, auth=(USERNAME, PASSWORD) ) - _database_name = "home database" if DATABASE is None else f"database {DATABASE}" + + # Try to connect to the database with GraphDatabase.driver( URI, auth=(USERNAME, PASSWORD), max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, ) as _sync_driver: - try: - with _sync_driver.session(database=DATABASE) as session: - try: - session.run("MATCH (n) RETURN n LIMIT 0") - logger.info(f"Connected to {DATABASE} at {URI}") - except neo4jExceptions.ServiceUnavailable as e: - logger.error( - f"{DATABASE} at {URI} is not available".capitalize() - ) - raise e - except neo4jExceptions.AuthError as e: - logger.error(f"Authentication failed for {DATABASE} at {URI}") - raise e - except neo4jExceptions.ClientError as e: - if e.code == "Neo.ClientError.Database.DatabaseNotFound": - logger.info( - f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize() - ) + for database in (DATABASE, None): + self._DATABASE = database + connected = False + try: - with _sync_driver.session() as session: - session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS") - logger.info(f"{DATABASE} at {URI} created".capitalize()) - except neo4jExceptions.ClientError as e: - if ( - e.code - == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" - ): - logger.warning( - "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead." - ) - logger.error(f"Failed to create {DATABASE} at {URI}") + with _sync_driver.session(database=database) as session: + try: + session.run("MATCH (n) RETURN n LIMIT 0") + logger.info(f"Connected to {database} at {URI}") + connected = True + except neo4jExceptions.ServiceUnavailable as e: + logger.error( + f"{database} at {URI} is not available".capitalize() + ) + raise e + except neo4jExceptions.AuthError as e: + logger.error(f"Authentication failed for {database} at {URI}") raise e + except neo4jExceptions.ClientError as e: + if e.code == "Neo.ClientError.Database.DatabaseNotFound": + logger.info( + f"{database} at {URI} not found. Try to create specified database.".capitalize() + ) + try: + with _sync_driver.session() as session: + session.run( + f"CREATE DATABASE `{database}` IF NOT EXISTS" + ) + logger.info(f"{database} at {URI} created".capitalize()) + connected = True + except ( + neo4jExceptions.ClientError, + neo4jExceptions.DatabaseError, + ) as e: + if ( + e.code + == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" + ) or ( + e.code == "Neo.DatabaseError.Statement.ExecutionFailed" + ): + if database is not None: + logger.warning( + "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." + ) + if database is None: + logger.error(f"Failed to create {database} at {URI}") + raise e + + if connected: + break def __post_init__(self): self._node_embed_algorithms = { @@ -117,7 +136,7 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) single_result = await result.single() logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}" ) return single_result["node_exists"] @@ -133,7 +152,7 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) single_result = await result.single() logger.debug( - f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}" ) return single_result["edgeExists"] diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 32fbaa10..a1a05759 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -257,7 +257,8 @@ class OracleKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys]) + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), ) params = {"workspace": self.db.workspace} res = await self.db.query(SQL, params, multirows=True) @@ -330,7 +331,9 @@ class OracleKVStorage(BaseKVStorage): return None async def change_status(self, id: str, status: str): - SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace)) + SQL = SQL_TEMPLATES["change_status"].format( + table_name=namespace_to_table_name(self.namespace) + ) params = {"workspace": self.db.workspace, "id": id, "status": status} await self.db.execute(SQL, params) @@ -623,6 +626,7 @@ N_T = { NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", } + def namespace_to_table_name(namespace: str) -> str: for k, v in N_T.items(): if is_namespace(namespace, k): diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 242c6832..6b925be3 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -236,7 +236,9 @@ class LightRAG: ) self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), embedding_func=self.embedding_func, ) @@ -244,15 +246,21 @@ class LightRAG: # add embedding func by walter #### self.full_docs = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS + ), embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS + ), embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION), + namespace=make_namespace( + self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION + ), embedding_func=self.embedding_func, ) #### @@ -260,17 +268,23 @@ class LightRAG: #### self.entities_vdb = self.vector_db_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES), + namespace=make_namespace( + self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES + ), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) self.relationships_vdb = self.vector_db_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS + ), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) self.chunks_vdb = self.vector_db_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS), + namespace=make_namespace( + self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS + ), embedding_func=self.embedding_func, ) @@ -280,7 +294,9 @@ class LightRAG: hashing_kv = self.llm_response_cache else: hashing_kv = self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), embedding_func=self.embedding_func, ) @@ -931,7 +947,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -948,7 +966,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -967,7 +987,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1008,7 +1030,9 @@ class LightRAG: global_config=asdict(self), hashing_kv=self.llm_response_cache or self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1039,7 +1063,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_funcne, ), @@ -1055,7 +1081,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ), @@ -1074,7 +1102,9 @@ class LightRAG: if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), global_config=asdict(self), embedding_func=self.embedding_func, ),