use namespace as neo4j database name

format

fix
This commit is contained in:
ArnoChen
2025-02-08 16:06:07 +08:00
parent 3f845e9e53
commit f5bf6a4af8
3 changed files with 107 additions and 54 deletions

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import inspect import inspect
import os import os
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict from typing import Any, Union, Tuple, List, Dict
import pipmaster as pm import pipmaster as pm
@@ -22,7 +23,7 @@ from tenacity import (
retry_if_exception_type, retry_if_exception_type,
) )
from lightrag.utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
@@ -45,51 +46,69 @@ class Neo4JStorage(BaseGraphStorage):
PASSWORD = os.environ["NEO4J_PASSWORD"] PASSWORD = os.environ["NEO4J_PASSWORD"]
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
DATABASE = os.environ.get( DATABASE = os.environ.get(
"NEO4J_DATABASE" "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
) # If this param is None, the home database will be used. If it is not None, the specified database will be used. )
self._DATABASE = DATABASE
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD) URI, auth=(USERNAME, PASSWORD)
) )
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
# Try to connect to the database
with GraphDatabase.driver( with GraphDatabase.driver(
URI, URI,
auth=(USERNAME, PASSWORD), auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
) as _sync_driver: ) as _sync_driver:
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try: try:
with _sync_driver.session(database=DATABASE) as session: with _sync_driver.session(database=database) as session:
try: try:
session.run("MATCH (n) RETURN n LIMIT 0") session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {DATABASE} at {URI}") logger.info(f"Connected to {database} at {URI}")
connected = True
except neo4jExceptions.ServiceUnavailable as e: except neo4jExceptions.ServiceUnavailable as e:
logger.error( logger.error(
f"{DATABASE} at {URI} is not available".capitalize() f"{database} at {URI} is not available".capitalize()
) )
raise e raise e
except neo4jExceptions.AuthError as e: except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {DATABASE} at {URI}") logger.error(f"Authentication failed for {database} at {URI}")
raise e raise e
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound": if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info( logger.info(
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize() f"{database} at {URI} not found. Try to create specified database.".capitalize()
) )
try: try:
with _sync_driver.session() as session: with _sync_driver.session() as session:
session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS") session.run(
logger.info(f"{DATABASE} at {URI} created".capitalize()) f"CREATE DATABASE `{database}` IF NOT EXISTS"
except neo4jExceptions.ClientError as e: )
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if ( if (
e.code e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand" == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
): ):
if database is not None:
logger.warning( logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead." "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
) )
logger.error(f"Failed to create {DATABASE} at {URI}") if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e raise e
if connected:
break
def __post_init__(self): def __post_init__(self):
self._node_embed_algorithms = { self._node_embed_algorithms = {
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
@@ -117,7 +136,7 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}' f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
) )
return single_result["node_exists"] return single_result["node_exists"]
@@ -133,7 +152,7 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
single_result = await result.single() single_result = await result.single()
logger.debug( logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}' f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
) )
return single_result["edgeExists"] return single_result["edgeExists"]

View File

@@ -257,7 +257,8 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: list[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys]) table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
) )
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, multirows=True) res = await self.db.query(SQL, params, multirows=True)
@@ -330,7 +331,9 @@ class OracleKVStorage(BaseKVStorage):
return None return None
async def change_status(self, id: str, status: str): async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace)) SQL = SQL_TEMPLATES["change_status"].format(
table_name=namespace_to_table_name(self.namespace)
)
params = {"workspace": self.db.workspace, "id": id, "status": status} params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params) await self.db.execute(SQL, params)
@@ -623,6 +626,7 @@ N_T = {
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
} }
def namespace_to_table_name(namespace: str) -> str: def namespace_to_table_name(namespace: str) -> str:
for k, v in N_T.items(): for k, v in N_T.items():
if is_namespace(namespace, k): if is_namespace(namespace, k):

View File

@@ -236,7 +236,9 @@ class LightRAG:
) )
self.llm_response_cache = self.key_string_value_json_storage_cls( self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
@@ -244,15 +246,21 @@ class LightRAG:
# add embedding func by walter # add embedding func by walter
#### ####
self.full_docs = self.key_string_value_json_storage_cls( self.full_docs = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.text_chunks = self.key_string_value_json_storage_cls( self.text_chunks = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.chunk_entity_relation_graph = self.graph_storage_cls( self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION), namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
#### ####
@@ -260,17 +268,23 @@ class LightRAG:
#### ####
self.entities_vdb = self.vector_db_storage_cls( self.entities_vdb = self.vector_db_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES), namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
meta_fields={"entity_name"}, meta_fields={"entity_name"},
) )
self.relationships_vdb = self.vector_db_storage_cls( self.relationships_vdb = self.vector_db_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS), namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"}, meta_fields={"src_id", "tgt_id"},
) )
self.chunks_vdb = self.vector_db_storage_cls( self.chunks_vdb = self.vector_db_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS), namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
@@ -280,7 +294,9 @@ class LightRAG:
hashing_kv = self.llm_response_cache hashing_kv = self.llm_response_cache
else: else:
hashing_kv = self.key_string_value_json_storage_cls( hashing_kv = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
@@ -931,7 +947,9 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -948,7 +966,9 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -967,7 +987,9 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -1008,7 +1030,9 @@ class LightRAG:
global_config=asdict(self), global_config=asdict(self),
hashing_kv=self.llm_response_cache hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls( or self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -1039,7 +1063,9 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_funcne, embedding_func=self.embedding_funcne,
), ),
@@ -1055,7 +1081,9 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),
@@ -1074,7 +1102,9 @@ class LightRAG:
if self.llm_response_cache if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config") and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls( else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE), namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
), ),