use namespace as neo4j database name

format

fix
This commit is contained in:
ArnoChen
2025-02-08 16:06:07 +08:00
parent 3f845e9e53
commit f5bf6a4af8
3 changed files with 107 additions and 54 deletions

View File

@@ -1,6 +1,7 @@
import asyncio
import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import pipmaster as pm
@@ -22,7 +23,7 @@ from tenacity import (
retry_if_exception_type,
)
from lightrag.utils import logger
from ..utils import logger
from ..base import BaseGraphStorage
@@ -45,51 +46,69 @@ class Neo4JStorage(BaseGraphStorage):
PASSWORD = os.environ["NEO4J_PASSWORD"]
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
DATABASE = os.environ.get(
"NEO4J_DATABASE"
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
self._DATABASE = DATABASE
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
)
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
# Try to connect to the database
with GraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
) as _sync_driver:
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try:
with _sync_driver.session(database=DATABASE) as session:
with _sync_driver.session(database=database) as session:
try:
session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {DATABASE} at {URI}")
logger.info(f"Connected to {database} at {URI}")
connected = True
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{DATABASE} at {URI} is not available".capitalize()
f"{database} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {DATABASE} at {URI}")
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
f"{database} at {URI} not found. Try to create specified database.".capitalize()
)
try:
with _sync_driver.session() as session:
session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
logger.info(f"{DATABASE} at {URI} created".capitalize())
except neo4jExceptions.ClientError as e:
session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
logger.error(f"Failed to create {DATABASE} at {URI}")
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if connected:
break
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
@@ -117,7 +136,7 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
)
return single_result["node_exists"]
@@ -133,7 +152,7 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query)
single_result = await result.single()
logger.debug(
f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
)
return single_result["edgeExists"]

View File

@@ -257,7 +257,8 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys])
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, multirows=True)
@@ -330,7 +331,9 @@ class OracleKVStorage(BaseKVStorage):
return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace))
SQL = SQL_TEMPLATES["change_status"].format(
table_name=namespace_to_table_name(self.namespace)
)
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
@@ -623,6 +626,7 @@ N_T = {
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
}
def namespace_to_table_name(namespace: str) -> str:
for k, v in N_T.items():
if is_namespace(namespace, k):

View File

@@ -236,7 +236,9 @@ class LightRAG:
)
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
@@ -244,15 +246,21 @@ class LightRAG:
# add embedding func by walter
####
self.full_docs = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
self.text_chunks = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION),
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
####
@@ -260,17 +268,23 @@ class LightRAG:
####
self.entities_vdb = self.vector_db_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES),
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS),
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb = self.vector_db_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS),
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
embedding_func=self.embedding_func,
)
@@ -280,7 +294,9 @@ class LightRAG:
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
@@ -931,7 +947,9 @@ class LightRAG:
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
@@ -948,7 +966,9 @@ class LightRAG:
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
@@ -967,7 +987,9 @@ class LightRAG:
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
@@ -1008,7 +1030,9 @@ class LightRAG:
global_config=asdict(self),
hashing_kv=self.llm_response_cache
or self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
@@ -1039,7 +1063,9 @@ class LightRAG:
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
@@ -1055,7 +1081,9 @@ class LightRAG:
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),
@@ -1074,7 +1102,9 @@ class LightRAG:
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_func,
),