From fd76e00c6a4ccb3d92db959c909d8962f71bee5b Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 03:48:19 +0800 Subject: [PATCH] Refactor storage initialization to separate object creation from data loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Split __post_init__ and initialize() • Move data loading to initialize() • Add FastAPI lifespan integration --- lightrag/api/lightrag_server.py | 14 ++++++++------ lightrag/kg/json_doc_status_impl.py | 5 ++++- lightrag/kg/json_kv_impl.py | 5 ++++- lightrag/kg/nano_vector_db_impl.py | 13 ++++++++----- lightrag/kg/shared_storage.py | 24 ++++++++++++++++-------- 5 files changed, 40 insertions(+), 21 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index c49de7a4..ca0958ee 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -135,14 +135,16 @@ def create_app(args): # Initialize database connections await rag.initialize_storages() + # Import necessary functions from shared_storage + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_storage_lock, + initialize_pipeline_namespace, + ) + await initialize_pipeline_namespace() + # Auto scan documents if enabled if args.auto_scan_at_startup: - # Import necessary functions from shared_storage - from lightrag.kg.shared_storage import ( - get_namespace_data, - get_storage_lock, - ) - # Check if a task is already running (with lock protection) pipeline_status = await get_namespace_data("pipeline_status") should_start_task = False diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 6a825db4..01c657fa 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -24,11 +24,14 @@ from .shared_storage import ( class JsonDocStatusStorage(DocStatusStorage): """JSON implementation of document status storage""" - async def __post_init__(self): + def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() + self._data = None + async def initialize(self): + """Initialize storage data""" # check need_init must before get_namespace_data need_init = try_initialize_namespace(self.namespace) self._data = await get_namespace_data(self.namespace) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 424730c1..8d707899 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -20,11 +20,14 @@ from .shared_storage import ( @final @dataclass class JsonKVStorage(BaseKVStorage): - async def __post_init__(self): + def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() + self._data = None + async def initialize(self): + """Initialize storage data""" # check need_init must before get_namespace_data need_init = try_initialize_namespace(self.namespace) self._data = await get_namespace_data(self.namespace) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index bbf991bf..86381379 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -16,15 +16,16 @@ if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") from nano_vectordb import NanoVectorDB -from threading import Lock as ThreadLock +from .shared_storage import get_storage_lock @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): - # Initialize lock only for file operations - self._storage_lock = ThreadLock() + # Initialize basic attributes + self._storage_lock = get_storage_lock() + self._client = None # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -40,7 +41,9 @@ class NanoVectorDBStorage(BaseVectorStorage): ) self._max_batch_size = self.global_config["embedding_batch_num"] - with self._storage_lock: + async def initialize(self): + """Initialize storage data""" + async with self._storage_lock: self._client = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name, @@ -163,5 +166,5 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self) -> None: - with self._storage_lock: + async with self._storage_lock: self._get_client().save() diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index ef946b44..5f795f0f 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -125,13 +125,21 @@ def initialize_share_data(workers: int = 1): # Mark as initialized _initialized = True - # Initialize pipeline status for document indexing control - pipeline_namespace = get_namespace_data("pipeline_status") - # Create a shared list object for history_messages - history_messages = _manager.list() if is_multiprocess else [] - pipeline_namespace.update( - { +async def initialize_pipeline_namespace(): + """ + Initialize pipeline namespace with default values. + """ + pipeline_namespace = await get_namespace_data("pipeline_status") + + async with get_storage_lock(): + # Check if already initialized by checking for required fields + if "busy" in pipeline_namespace: + return + + # Create a shared list object for history_messages + history_messages = _manager.list() if is_multiprocess else [] + pipeline_namespace.update({ "busy": False, # Control concurrent processes "job_name": "Default Job", # Current job name (indexing files/indexing texts) "job_start": None, # Job start time @@ -141,8 +149,8 @@ def initialize_share_data(workers: int = 1): "request_pending": False, # Flag for pending request for processing "latest_message": "", # Latest message from pipeline processing "history_messages": history_messages, # 使用共享列表对象 - } - ) + }) + direct_log(f"Process {os.getpid()} Pipeline namespace initialized") async def get_update_flags(namespace: str):