refactor database connection management and improve storage lifecycle handling

update
This commit is contained in:
ArnoChen
2025-02-19 03:46:18 +08:00
parent 780d0b45f7
commit e194e04226
6 changed files with 376 additions and 195 deletions

View File

@@ -17,6 +17,7 @@ from .base import (
DocStatusStorage,
QueryParam,
StorageNameSpace,
StoragesStatus,
)
from .namespace import NameSpace, make_namespace
from .operate import (
@@ -348,6 +349,9 @@ class LightRAG:
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
# Ownership
# When False (the default) the instance initializes its own storages after
# construction and finalizes them in __del__. When True it does neither;
# presumably the hosting server drives initialize_storages() /
# finalize_storages() itself -- confirm against the server code.
is_managed_by_server: bool = False
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
@@ -440,7 +444,10 @@ class LightRAG:
**self.vector_db_storage_cls_kwargs,
}
# show config
# Life cycle
self.storages_status = StoragesStatus.NOT_CREATED
# Show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -547,6 +554,65 @@ class LightRAG:
)
)
self.storages_status = StoragesStatus.CREATED
# Initialize storages
if not self.is_managed_by_server:
loop = always_get_an_event_loop()
loop.run_until_complete(self.initialize_storages())
def __del__(self):
    """Finalize the storages when a self-managed instance is collected.

    Nothing happens when ``is_managed_by_server`` is true. Otherwise
    ``finalize_storages()`` is run to completion on the loop returned by
    ``always_get_an_event_loop()``.

    This is strictly best effort. A finalizer can run on a partially
    constructed object, from inside a running event loop, or during
    interpreter shutdown, and exceptions raised here can never reach a
    caller, so every failure is logged instead of raised.
    """
    if self.is_managed_by_server:
        return
    try:
        # `storages_status` is missing if construction failed before it was
        # assigned. Anything other than INITIALIZED means there is nothing
        # to finalize, so do not even acquire an event loop.
        status = getattr(self, "storages_status", None)
        if status != StoragesStatus.INITIALIZED:
            return
        loop = always_get_an_event_loop()
        if loop.is_running():
            # run_until_complete() raises RuntimeError on a running loop and
            # would leave the finalize coroutine un-awaited.
            logger.warning(
                "Event loop is running; storages were not finalized in "
                "__del__. Await finalize_storages() explicitly instead."
            )
            return
        loop.run_until_complete(self.finalize_storages())
    except Exception as exc:
        logger.warning("Failed to finalize storages in __del__: %r", exc)
async def initialize_storages(self):
    """Initialize every storage backend concurrently.

    Only acts while ``storages_status`` is ``StoragesStatus.CREATED``; in
    any other state the call is a no-op. Once every ``initialize()``
    coroutine has completed, the status becomes
    ``StoragesStatus.INITIALIZED``.
    """
    if self.storages_status != StoragesStatus.CREATED:
        return

    backends = (
        self.full_docs,
        self.text_chunks,
        self.entities_vdb,
        self.relationships_vdb,
        self.chunks_vdb,
        self.chunk_entity_relation_graph,
        self.llm_response_cache,
        self.doc_status,
    )
    # Falsy entries (e.g. a backend that is None) are skipped.
    pending = [backend.initialize() for backend in backends if backend]
    await asyncio.gather(*pending)

    self.storages_status = StoragesStatus.INITIALIZED
    logger.debug("Initialized Storages")
async def finalize_storages(self):
    """Finalize every storage backend concurrently.

    Only acts while ``storages_status`` is ``StoragesStatus.INITIALIZED``;
    in any other state the call is a no-op.

    All ``finalize()`` coroutines are awaited to completion even if some of
    them fail, so one failing backend cannot leave the others half
    finalized. Every failure is logged and the first one is re-raised.

    Raises:
        BaseException: The first exception raised by a backend's
            ``finalize()``. The status is then left unchanged
            (``INITIALIZED``).
    """
    if self.storages_status != StoragesStatus.INITIALIZED:
        return

    storages = (
        self.full_docs,
        self.text_chunks,
        self.entities_vdb,
        self.relationships_vdb,
        self.chunks_vdb,
        self.chunk_entity_relation_graph,
        self.llm_response_cache,
        self.doc_status,
    )
    # return_exceptions=True makes gather wait for every finalizer. Without
    # it the first error is propagated immediately while the remaining
    # finalizers are still pending.
    results = await asyncio.gather(
        *[storage.finalize() for storage in storages if storage],
        return_exceptions=True,
    )

    failures = [res for res in results if isinstance(res, BaseException)]
    for failure in failures:
        logger.error(f"Failed to finalize a storage: {failure!r}")
    if failures:
        # Same outcome as before for callers: the error propagates and the
        # status is not advanced to FINALIZED.
        raise failures[0]

    logger.debug("Finalized Storages")
    self.storages_status = StoragesStatus.FINALIZED
async def get_graph_labels(self):
    """Return whatever the entity-relation graph storage reports as its labels."""
    graph = self.chunk_entity_relation_graph
    return await graph.get_all_labels()