diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 4d0a6390..c49de7a4 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -143,13 +143,10 @@ def create_app(args):
             get_storage_lock,
         )
 
-        # Get pipeline status and lock
-        pipeline_status = get_namespace_data("pipeline_status")
-        storage_lock = get_storage_lock()
-
         # Check if a task is already running (with lock protection)
+        pipeline_status = await get_namespace_data("pipeline_status")
         should_start_task = False
-        with storage_lock:
+        async with get_storage_lock():
             if not pipeline_status.get("busy", False):
                 should_start_task = True
         # Only start the task if no other task is running
diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 05e6da37..6a825db4 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -24,17 +24,17 @@ from .shared_storage import (
 class JsonDocStatusStorage(DocStatusStorage):
     """JSON implementation of document status storage"""
 
-    def __post_init__(self):
+    async def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._storage_lock = get_storage_lock()
         # check need_init must before get_namespace_data
         need_init = try_initialize_namespace(self.namespace)
-        self._data = get_namespace_data(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
         if need_init:
             loaded_data = load_json(self._file_name) or {}
-            with self._storage_lock:
+            async with self._storage_lock:
                 self._data.update(loaded_data)
             logger.info(
                 f"Loaded document status storage with {len(loaded_data)} records"
             )
@@ -42,12 +42,12 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        with self._storage_lock:
+        async with self._storage_lock:
             return set(keys) - set(self._data.keys())
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
-        with self._storage_lock:
+        async with self._storage_lock:
             for id in ids:
                 data = self._data.get(id, None)
                 if data:
@@ -57,7 +57,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status.value: 0 for status in DocStatus}
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc in self._data.values():
                 counts[doc["status"]] += 1
         return counts
@@ -67,7 +67,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
         result = {}
-        with self._storage_lock:
+        async with self._storage_lock:
             for k, v in self._data.items():
                 if v["status"] == status.value:
                     try:
@@ -83,7 +83,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         return result
 
     async def index_done_callback(self) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             data_dict = (
                 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
             )
@@ -94,21 +94,21 @@ class JsonDocStatusStorage(DocStatusStorage):
         if not data:
             return
 
-        with self._storage_lock:
+        async with self._storage_lock:
             self._data.update(data)
         await self.index_done_callback()
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return self._data.get(id)
 
     async def delete(self, doc_ids: list[str]):
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
         await self.index_done_callback()
 
     async def drop(self) -> None:
         """Drop the storage"""
-        with self._storage_lock:
+        async with self._storage_lock:
             self._data.clear()
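The storage methods above replace `with self._storage_lock:` with `async with self._storage_lock:`. As a minimal standalone sketch of why that matters (the names below are illustrative, not from the repo): waiting on an `asyncio.Lock` with `async with` suspends only the waiting coroutine and lets the event loop keep running other tasks, whereas a blocking `with` on a thread or process lock would stall the whole loop.

```python
import asyncio

# Illustrative only: a shared dict guarded the way the storage
# classes above guard self._data.
_data: dict[str, str] = {}
_lock = asyncio.Lock()

async def upsert(key: str, value: str) -> None:
    # Waiting here yields to the event loop instead of blocking it.
    async with _lock:
        _data[key] = value

async def main() -> None:
    await asyncio.gather(*(upsert(f"k{i}", "v") for i in range(3)))

asyncio.run(main())
```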
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index a4ce91a5..424730c1 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -20,33 +20,33 @@ from .shared_storage import (
 @final
 @dataclass
 class JsonKVStorage(BaseKVStorage):
 
-    def __post_init__(self):
+    async def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._storage_lock = get_storage_lock()
         # check need_init must before get_namespace_data
         need_init = try_initialize_namespace(self.namespace)
-        self._data = get_namespace_data(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
         if need_init:
             loaded_data = load_json(self._file_name) or {}
-            with self._storage_lock:
+            async with self._storage_lock:
                 self._data.update(loaded_data)
             logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
 
     async def index_done_callback(self) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             data_dict = (
                 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
             )
             write_json(data_dict, self._file_name)
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        with self._storage_lock:
+        async with self._storage_lock:
             return self._data.get(id)
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return [
                 (
                     {k: v for k, v in self._data[id].items()}
@@ -57,19 +57,19 @@ class JsonKVStorage(BaseKVStorage):
         ]
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        with self._storage_lock:
+        async with self._storage_lock:
             left_data = {k: v for k, v in data.items() if k not in self._data}
             self._data.update(left_data)
 
     async def delete(self, ids: list[str]) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc_id in ids:
                 self._data.pop(doc_id, None)
         await self.index_done_callback()
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 7ac0d625..ef946b44 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -3,7 +3,7 @@ import sys
 import asyncio
 from multiprocessing.synchronize import Lock as ProcessLock
 from multiprocessing import Manager
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, TypeVar, Generic
 
 
 # Define a direct print function for critical logs that must be visible in all processes
@@ -15,6 +15,43 @@ def direct_log(message, level="INFO"):
     print(f"{level}: {message}", file=sys.stderr, flush=True)
 
 
+T = TypeVar('T')
+
+class UnifiedLock(Generic[T]):
+    """Unified lock wrapper providing a consistent interface over sync and async locks"""
+    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
+        self._lock = lock
+        self._is_async = is_async
+
+    async def __aenter__(self) -> 'UnifiedLock[T]':
+        """Async context manager entry"""
+        if self._is_async:
+            await self._lock.acquire()
+        else:
+            self._lock.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit"""
+        if self._is_async:
+            self._lock.release()
+        else:
+            self._lock.release()
+
+    def __enter__(self) -> 'UnifiedLock[T]':
+        """Sync context manager entry (for backward compatibility only)"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for asyncio.Lock")
+        self._lock.acquire()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Sync context manager exit (for backward compatibility only)"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for asyncio.Lock")
+        self._lock.release()
+
+
 LockType = Union[ProcessLock, asyncio.Lock]
 
 is_multiprocess = None
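A rough usage sketch of the new `UnifiedLock` (the locks are constructed inline here purely for illustration; in the module a wrapped `_global_lock` is always obtained via `get_storage_lock()`): one `async with` call site covers both an `asyncio.Lock` and a multiprocessing lock.

```python
import asyncio
from multiprocessing import Lock as MPLock

from lightrag.kg.shared_storage import UnifiedLock

async def demo() -> None:
    # Single-process mode: wraps asyncio.Lock, acquisition is awaited.
    async with UnifiedLock(lock=asyncio.Lock(), is_async=True):
        pass

    # Multiprocess mode: wraps a multiprocessing lock. acquire()/release()
    # are plain blocking calls, but the call site stays identical.
    async with UnifiedLock(lock=MPLock(), is_async=False):
        pass

asyncio.run(demo())
```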
"""异步上下文管理器出口""" + if self._is_async: + self._lock.release() + else: + self._lock.release() + + def __enter__(self) -> 'UnifiedLock[T]': + """同步上下文管理器入口(仅用于向后兼容)""" + if self._is_async: + raise RuntimeError("Use 'async with' for asyncio.Lock") + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """同步上下文管理器出口(仅用于向后兼容)""" + if self._is_async: + raise RuntimeError("Use 'async with' for asyncio.Lock") + self._lock.release() + + LockType = Union[ProcessLock, asyncio.Lock] is_multiprocess = None @@ -117,26 +154,21 @@ async def get_update_flags(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - if is_multiprocess: - with _global_lock: - if namespace not in _update_flags: - if _manager is not None: - _update_flags[namespace] = _manager.list() - direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") - - if _manager is not None: - new_update_flag = _manager.Value('b', False) - _update_flags[namespace].append(new_update_flag) - return new_update_flag - else: - async with _global_lock: - if namespace not in _update_flags: + async with get_storage_lock(): + if namespace not in _update_flags: + if is_multiprocess and _manager is not None: + _update_flags[namespace] = _manager.list() + else: _update_flags[namespace] = [] - direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") - + direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") + + if is_multiprocess and _manager is not None: + new_update_flag = _manager.Value('b', False) + else: new_update_flag = False - _update_flags[namespace].append(new_update_flag) - return new_update_flag + + _update_flags[namespace].append(new_update_flag) + return new_update_flag async def set_update_flag(namespace: str): """Set all update flag of namespace to indicate storage needs updating""" @@ -144,19 +176,14 @@ async def set_update_flag(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - if is_multiprocess: - with _global_lock: - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") - # Update flags for multiprocess mode - for i in range(len(_update_flags[namespace])): + async with get_storage_lock(): + if namespace not in _update_flags: + raise ValueError(f"Namespace {namespace} not found in update flags") + # Update flags for both modes + for i in range(len(_update_flags[namespace])): + if is_multiprocess: _update_flags[namespace][i].value = True - else: - async with _global_lock: - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") - # Update flags for single process mode - for i in range(len(_update_flags[namespace])): + else: _update_flags[namespace][i] = True @@ -182,9 +209,12 @@ def try_initialize_namespace(namespace: str) -> bool: return False -def get_storage_lock() -> LockType: - """return storage lock for data consistency""" - return _global_lock +def get_storage_lock() -> UnifiedLock: + """return unified storage lock for data consistency""" + return UnifiedLock( + lock=_global_lock, + is_async=not is_multiprocess + ) async def get_namespace_data(namespace: str) -> Dict[str, Any]: @@ -196,14 +226,11 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]: ) raise ValueError("Shared dictionaries not initialized") - if is_multiprocess: - with _global_lock: - 
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 9c8f84ff..4b85a3b7 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -672,12 +672,12 @@ class LightRAG:
         from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock
 
         # Get pipeline status shared data and lock
-        pipeline_status = get_namespace_data("pipeline_status")
+        pipeline_status = await get_namespace_data("pipeline_status")
         storage_lock = get_storage_lock()
 
         # Check if another process is already processing the queue
         process_documents = False
-        with storage_lock:
+        async with storage_lock:
             # Ensure only one worker is processing documents
             if not pipeline_status.get("busy", False):
                 # Cleaning history_messages without breaking it as a shared list object
@@ -732,8 +732,7 @@
                     break
 
                 # Update pipeline status with document count (with lock)
-                with storage_lock:
-                    pipeline_status["docs"] = len(to_process_docs)
+                pipeline_status["docs"] = len(to_process_docs)
 
                 # 2. split docs into chunks, insert chunks, update doc status
                 docs_batches = [
@@ -852,7 +851,7 @@
 
                 # Check if there's a pending request to process more documents (with lock)
                 has_pending_request = False
-                with storage_lock:
+                async with storage_lock:
                     has_pending_request = pipeline_status.get("request_pending", False)
                     if has_pending_request:
                         # Clear the request flag before checking for more documents
@@ -867,13 +866,13 @@
                     pipeline_status["history_messages"].append(log_message)
 
         finally:
-            # Always reset busy status when done or if an exception occurs (with lock)
-            with storage_lock:
-                pipeline_status["busy"] = False
             log_message = "Document processing pipeline completed"
             logger.info(log_message)
-            pipeline_status["latest_message"] = log_message
-            pipeline_status["history_messages"].append(log_message)
+            # Always reset busy status when done or if an exception occurs (with lock)
+            async with storage_lock:
+                pipeline_status["busy"] = False
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
 
     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
@@ -911,7 +910,7 @@
             # Get pipeline_status and update latest_message and history_messages
             from lightrag.kg.shared_storage import get_namespace_data
 
-            pipeline_status = get_namespace_data("pipeline_status")
+            pipeline_status = await get_namespace_data("pipeline_status")
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)
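Taken together, the `lightrag.py` hunks implement a single-worker gate around the document pipeline: the busy flag is read and flipped under the unified lock, so exactly one caller proceeds while the rest only record a pending request. A condensed sketch of that gate (`try_claim_pipeline` is a hypothetical helper, simplified from the hunks above, not the verbatim pipeline code):

```python
from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock

async def try_claim_pipeline() -> bool:
    """Return True if this caller should run the document pipeline."""
    pipeline_status = await get_namespace_data("pipeline_status")
    async with get_storage_lock():
        if pipeline_status.get("busy", False):
            # Another worker owns the pipeline; flag that more work arrived.
            pipeline_status["request_pending"] = True
            return False
        pipeline_status["busy"] = True
        return True
```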