refactor: migrate synchronous locks to async locks for improved concurrency

• Add UnifiedLock wrapper class
• Convert with blocks to async with
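
A minimal sketch of the pattern this commit applies throughout (illustrative names, assuming an asyncio event loop):

    import asyncio

    # The old code guarded shared state with a synchronous lock held via
    # "with", which blocks the event-loop thread while waiting. The async
    # replacement suspends only the waiting coroutine:
    lock = asyncio.Lock()

    async def update_status(shared: dict) -> None:
        async with lock:  # awaited acquisition; other tasks keep running
            shared["busy"] = True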
yangdx
2025-03-01 02:22:35 +08:00
parent a721421bd8
commit b3328542c7
5 changed files with 102 additions and 79 deletions

View File

@@ -143,13 +143,10 @@ def create_app(args):
             get_storage_lock,
         )
-        # Get pipeline status and lock
-        pipeline_status = get_namespace_data("pipeline_status")
-        storage_lock = get_storage_lock()
         # Check if a task is already running (with lock protection)
+        pipeline_status = await get_namespace_data("pipeline_status")
         should_start_task = False
-        with storage_lock:
+        async with get_storage_lock():
             if not pipeline_status.get("busy", False):
                 should_start_task = True
         # Only start the task if no other task is running

View File

@@ -24,17 +24,17 @@ from .shared_storage import (
 class JsonDocStatusStorage(DocStatusStorage):
     """JSON implementation of document status storage"""

-    def __post_init__(self):
+    async def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._storage_lock = get_storage_lock()
         # check need_init must before get_namespace_data
         need_init = try_initialize_namespace(self.namespace)
-        self._data = get_namespace_data(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
         if need_init:
             loaded_data = load_json(self._file_name) or {}
-            with self._storage_lock:
+            async with self._storage_lock:
                 self._data.update(loaded_data)
                 logger.info(
                     f"Loaded document status storage with {len(loaded_data)} records"
@@ -42,12 +42,12 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        with self._storage_lock:
+        async with self._storage_lock:
             return set(keys) - set(self._data.keys())

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
-        with self._storage_lock:
+        async with self._storage_lock:
             for id in ids:
                 data = self._data.get(id, None)
                 if data:
@@ -57,7 +57,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status.value: 0 for status in DocStatus}
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc in self._data.values():
                 counts[doc["status"]] += 1
         return counts
@@ -67,7 +67,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
         result = {}
-        with self._storage_lock:
+        async with self._storage_lock:
             for k, v in self._data.items():
                 if v["status"] == status.value:
                     try:
@@ -83,7 +83,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         return result

     async def index_done_callback(self) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             data_dict = (
                 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
             )
@@ -94,21 +94,21 @@ class JsonDocStatusStorage(DocStatusStorage):
         if not data:
             return
-        with self._storage_lock:
+        async with self._storage_lock:
             self._data.update(data)
         await self.index_done_callback()

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return self._data.get(id)

     async def delete(self, doc_ids: list[str]):
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
         await self.index_done_callback()

     async def drop(self) -> None:
         """Drop the storage"""
-        with self._storage_lock:
+        async with self._storage_lock:
             self._data.clear()
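
One consequence of the async def __post_init__ change is worth flagging: dataclasses call __post_init__ synchronously from the generated __init__, so an async version returns a coroutine that __init__ silently discards. The body only runs if whoever constructs the storage awaits it explicitly; a hedged sketch (constructor arguments hypothetical):

    storage = JsonDocStatusStorage(namespace="doc_status", global_config=config)
    await storage.__post_init__()  # the dataclass __init__ did not run this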

View File

@@ -20,33 +20,33 @@ from .shared_storage import (
 @final
 @dataclass
 class JsonKVStorage(BaseKVStorage):
-    def __post_init__(self):
+    async def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._storage_lock = get_storage_lock()
         # check need_init must before get_namespace_data
         need_init = try_initialize_namespace(self.namespace)
-        self._data = get_namespace_data(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
         if need_init:
             loaded_data = load_json(self._file_name) or {}
-            with self._storage_lock:
+            async with self._storage_lock:
                 self._data.update(loaded_data)
                 logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")

     async def index_done_callback(self) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             data_dict = (
                 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
             )
             write_json(data_dict, self._file_name)

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        with self._storage_lock:
+        async with self._storage_lock:
             return self._data.get(id)

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return [
                 (
                     {k: v for k, v in self._data[id].items()}
@@ -57,19 +57,19 @@ class JsonKVStorage(BaseKVStorage):
             ]

     async def filter_keys(self, keys: set[str]) -> set[str]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        with self._storage_lock:
+        async with self._storage_lock:
             left_data = {k: v for k, v in data.items() if k not in self._data}
             self._data.update(left_data)

     async def delete(self, ids: list[str]) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc_id in ids:
                 self._data.pop(doc_id, None)
         await self.index_done_callback()

View File

@@ -3,7 +3,7 @@ import sys
 import asyncio
 from multiprocessing.synchronize import Lock as ProcessLock
 from multiprocessing import Manager
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, TypeVar, Generic

 # Define a direct print function for critical logs that must be visible in all processes
@@ -15,6 +15,43 @@ def direct_log(message, level="INFO"):
     print(f"{level}: {message}", file=sys.stderr, flush=True)

+
+T = TypeVar('T')
+
+class UnifiedLock(Generic[T]):
+    """Unified lock wrapper providing a single interface over sync and async locks"""
+
+    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
+        self._lock = lock
+        self._is_async = is_async
+
+    async def __aenter__(self) -> 'UnifiedLock[T]':
+        """Async context manager entry"""
+        if self._is_async:
+            await self._lock.acquire()
+        else:
+            self._lock.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit"""
+        if self._is_async:
+            self._lock.release()
+        else:
+            self._lock.release()
+
+    def __enter__(self) -> 'UnifiedLock[T]':
+        """Sync context manager entry (backward compatibility only)"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for asyncio.Lock")
+        self._lock.acquire()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Sync context manager exit (backward compatibility only)"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for asyncio.Lock")
+        self._lock.release()
+
+
 LockType = Union[ProcessLock, asyncio.Lock]

 is_multiprocess = None
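
A quick usage sketch for the wrapper above (assuming an asyncio event loop; the demo function is hypothetical):

    import asyncio

    async def demo() -> None:
        # Single-process mode: wrap an asyncio.Lock and acquire via 'async with'
        ulock = UnifiedLock(asyncio.Lock(), is_async=True)
        async with ulock:
            pass  # critical section
        # Plain 'with' on an async-backed UnifiedLock raises RuntimeError,
        # which is the backward-compatibility guard coded above.

    asyncio.run(demo())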
@@ -117,26 +154,21 @@ async def get_update_flags(namespace: str):
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")

-    if is_multiprocess:
-        with _global_lock:
-            if namespace not in _update_flags:
-                if _manager is not None:
-                    _update_flags[namespace] = _manager.list()
-                    direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
-            if _manager is not None:
-                new_update_flag = _manager.Value('b', False)
-                _update_flags[namespace].append(new_update_flag)
-                return new_update_flag
-    else:
-        async with _global_lock:
-            if namespace not in _update_flags:
-                _update_flags[namespace] = []
-                direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
-            new_update_flag = False
-            _update_flags[namespace].append(new_update_flag)
-            return new_update_flag
+    async with get_storage_lock():
+        if namespace not in _update_flags:
+            if is_multiprocess and _manager is not None:
+                _update_flags[namespace] = _manager.list()
+            else:
+                _update_flags[namespace] = []
+            direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
+        if is_multiprocess and _manager is not None:
+            new_update_flag = _manager.Value('b', False)
+        else:
+            new_update_flag = False
+        _update_flags[namespace].append(new_update_flag)
+        return new_update_flag

 async def set_update_flag(namespace: str):
     """Set all update flag of namespace to indicate storage needs updating"""
@@ -144,19 +176,14 @@ async def set_update_flag(namespace: str):
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")

-    if is_multiprocess:
-        with _global_lock:
-            if namespace not in _update_flags:
-                raise ValueError(f"Namespace {namespace} not found in update flags")
-            # Update flags for multiprocess mode
-            for i in range(len(_update_flags[namespace])):
-                _update_flags[namespace][i].value = True
-    else:
-        async with _global_lock:
-            if namespace not in _update_flags:
-                raise ValueError(f"Namespace {namespace} not found in update flags")
-            # Update flags for single process mode
-            for i in range(len(_update_flags[namespace])):
-                _update_flags[namespace][i] = True
+    async with get_storage_lock():
+        if namespace not in _update_flags:
+            raise ValueError(f"Namespace {namespace} not found in update flags")
+        # Update flags for both modes
+        for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
+                _update_flags[namespace][i].value = True
+            else:
+                _update_flags[namespace][i] = True
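
A hedged sketch of how these two helpers pair at a call site (the namespace value is illustrative):

    # A consumer registers its own flag for a namespace...
    my_flag = await get_update_flags("pipeline_status")

    # ...and a writer later flips every registered flag for that namespace
    # to signal that the shared storage needs reloading.
    await set_update_flag("pipeline_status")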
@@ -182,9 +209,12 @@ def try_initialize_namespace(namespace: str) -> bool:
     return False

-def get_storage_lock() -> LockType:
-    """return storage lock for data consistency"""
-    return _global_lock
+def get_storage_lock() -> UnifiedLock:
+    """return unified storage lock for data consistency"""
+    return UnifiedLock(
+        lock=_global_lock,
+        is_async=not is_multiprocess
+    )

 async def get_namespace_data(namespace: str) -> Dict[str, Any]:
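
Note that each call to get_storage_lock constructs a fresh UnifiedLock, but every wrapper shares the one module-level _global_lock, so callers still exclude each other; a minimal sketch (caller hypothetical):

    async def guarded_write(shared: dict) -> None:
        # This wrapper and any other get_storage_lock() caller contend
        # for the same underlying _global_lock.
        async with get_storage_lock():
            shared["key"] = "value"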
@@ -196,14 +226,11 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
         )
         raise ValueError("Shared dictionaries not initialized")

-    if is_multiprocess:
-        with _global_lock:
-            if namespace not in _shared_dicts:
-                if _manager is not None:
-                    _shared_dicts[namespace] = _manager.dict()
-    else:
-        async with _global_lock:
-            if namespace not in _shared_dicts:
-                _shared_dicts[namespace] = {}
+    async with get_storage_lock():
+        if namespace not in _shared_dicts:
+            if is_multiprocess and _manager is not None:
+                _shared_dicts[namespace] = _manager.dict()
+            else:
+                _shared_dicts[namespace] = {}

     return _shared_dicts[namespace]

View File

@@ -672,12 +672,12 @@ class LightRAG:
         from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock

         # Get pipeline status shared data and lock
-        pipeline_status = get_namespace_data("pipeline_status")
+        pipeline_status = await get_namespace_data("pipeline_status")
         storage_lock = get_storage_lock()

         # Check if another process is already processing the queue
         process_documents = False
-        with storage_lock:
+        async with storage_lock:
             # Ensure only one worker is processing documents
             if not pipeline_status.get("busy", False):
                 # Cleaning history_messages without breaking it as a shared list object
@@ -732,8 +732,7 @@ class LightRAG:
                 break

             # Update pipeline status with document count (with lock)
-            with storage_lock:
-                pipeline_status["docs"] = len(to_process_docs)
+            pipeline_status["docs"] = len(to_process_docs)

             # 2. split docs into chunks, insert chunks, update doc status
             docs_batches = [
@@ -852,7 +851,7 @@ class LightRAG:
                 # Check if there's a pending request to process more documents (with lock)
                 has_pending_request = False
-                with storage_lock:
+                async with storage_lock:
                     has_pending_request = pipeline_status.get("request_pending", False)

                 if has_pending_request:
                     # Clear the request flag before checking for more documents
@@ -867,13 +866,13 @@ class LightRAG:
                 pipeline_status["history_messages"].append(log_message)
         finally:
-            # Always reset busy status when done or if an exception occurs (with lock)
-            with storage_lock:
-                pipeline_status["busy"] = False
             log_message = "Document processing pipeline completed"
             logger.info(log_message)
-            pipeline_status["latest_message"] = log_message
-            pipeline_status["history_messages"].append(log_message)
+            # Always reset busy status when done or if an exception occurs (with lock)
+            async with storage_lock:
+                pipeline_status["busy"] = False
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
@@ -911,7 +910,7 @@ class LightRAG:
             # Get pipeline_status and update latest_message and history_messages
             from lightrag.kg.shared_storage import get_namespace_data
-            pipeline_status = get_namespace_data("pipeline_status")
+            pipeline_status = await get_namespace_data("pipeline_status")
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)