refactor: migrate synchronous locks to async locks for improved concurrency
• Add UnifiedLock wrapper class
• Convert with blocks to async with (see the usage sketch below)
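For orientation, here is a minimal, self-contained sketch of the calling pattern this migration moves to. The helpers below are simplified single-process stand-ins written for this note, not the implementations in the diff:

import asyncio

_lock = asyncio.Lock()
_namespaces: dict[str, dict] = {}

def get_storage_lock() -> asyncio.Lock:
    # Stand-in for the shared-storage lock factory.
    return _lock

async def get_namespace_data(namespace: str) -> dict:
    # Stand-in for the awaited namespace accessor.
    async with _lock:
        return _namespaces.setdefault(namespace, {})

async def worker() -> None:
    # Before: `with storage_lock:` blocked the event loop while waiting.
    # After: `async with` yields control to other tasks during acquisition.
    pipeline_status = await get_namespace_data("pipeline_status")
    async with get_storage_lock():
        if not pipeline_status.get("busy", False):
            pipeline_status["busy"] = True

asyncio.run(worker())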
@@ -143,13 +143,10 @@ def create_app(args):
         get_storage_lock,
     )
 
-    # Get pipeline status and lock
-    pipeline_status = get_namespace_data("pipeline_status")
-    storage_lock = get_storage_lock()
-
     # Check if a task is already running (with lock protection)
+    pipeline_status = await get_namespace_data("pipeline_status")
     should_start_task = False
-    with storage_lock:
+    async with get_storage_lock():
         if not pipeline_status.get("busy", False):
             should_start_task = True
             # Only start the task if no other task is running

@@ -24,17 +24,17 @@ from .shared_storage import (
 class JsonDocStatusStorage(DocStatusStorage):
     """JSON implementation of document status storage"""
 
-    def __post_init__(self):
+    async def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._storage_lock = get_storage_lock()
 
         # need_init must be checked before get_namespace_data
         need_init = try_initialize_namespace(self.namespace)
-        self._data = get_namespace_data(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
         if need_init:
             loaded_data = load_json(self._file_name) or {}
-            with self._storage_lock:
+            async with self._storage_lock:
                 self._data.update(loaded_data)
                 logger.info(
                     f"Loaded document status storage with {len(loaded_data)} records"

@@ -42,12 +42,12 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        with self._storage_lock:
+        async with self._storage_lock:
             return set(keys) - set(self._data.keys())
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
-        with self._storage_lock:
+        async with self._storage_lock:
             for id in ids:
                 data = self._data.get(id, None)
                 if data:

@@ -57,7 +57,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status.value: 0 for status in DocStatus}
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc in self._data.values():
                 counts[doc["status"]] += 1
             return counts

@@ -67,7 +67,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
         result = {}
-        with self._storage_lock:
+        async with self._storage_lock:
             for k, v in self._data.items():
                 if v["status"] == status.value:
                     try:

@@ -83,7 +83,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         return result
 
     async def index_done_callback(self) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             data_dict = (
                 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
             )

@@ -94,21 +94,21 @@ class JsonDocStatusStorage(DocStatusStorage):
         if not data:
             return
 
-        with self._storage_lock:
+        async with self._storage_lock:
             self._data.update(data)
         await self.index_done_callback()
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return self._data.get(id)
 
     async def delete(self, doc_ids: list[str]):
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
         await self.index_done_callback()
 
     async def drop(self) -> None:
         """Drop the storage"""
-        with self._storage_lock:
+        async with self._storage_lock:
             self._data.clear()

@@ -20,33 +20,33 @@ from .shared_storage import (
 @final
 @dataclass
 class JsonKVStorage(BaseKVStorage):
-    def __post_init__(self):
+    async def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._storage_lock = get_storage_lock()
 
         # need_init must be checked before get_namespace_data
         need_init = try_initialize_namespace(self.namespace)
-        self._data = get_namespace_data(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
         if need_init:
             loaded_data = load_json(self._file_name) or {}
-            with self._storage_lock:
+            async with self._storage_lock:
                 self._data.update(loaded_data)
                 logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
 
     async def index_done_callback(self) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             data_dict = (
                 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
             )
             write_json(data_dict, self._file_name)
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        with self._storage_lock:
+        async with self._storage_lock:
             return self._data.get(id)
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return [
                 (
                     {k: v for k, v in self._data[id].items()}

@@ -57,19 +57,19 @@ class JsonKVStorage(BaseKVStorage):
             ]
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
-        with self._storage_lock:
+        async with self._storage_lock:
             return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        with self._storage_lock:
+        async with self._storage_lock:
             left_data = {k: v for k, v in data.items() if k not in self._data}
             self._data.update(left_data)
 
     async def delete(self, ids: list[str]) -> None:
-        with self._storage_lock:
+        async with self._storage_lock:
             for doc_id in ids:
                 self._data.pop(doc_id, None)
         await self.index_done_callback()

@@ -3,7 +3,7 @@ import sys
 import asyncio
 from multiprocessing.synchronize import Lock as ProcessLock
 from multiprocessing import Manager
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, TypeVar, Generic
 
 
 # Define a direct print function for critical logs that must be visible in all processes

@@ -15,6 +15,43 @@ def direct_log(message, level="INFO"):
     print(f"{level}: {message}", file=sys.stderr, flush=True)
 
 
+T = TypeVar('T')
+
+class UnifiedLock(Generic[T]):
+    """Unified lock wrapper providing a single interface over sync and async locks"""
+    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
+        self._lock = lock
+        self._is_async = is_async
+
+    async def __aenter__(self) -> 'UnifiedLock[T]':
+        """Async context manager entry"""
+        if self._is_async:
+            await self._lock.acquire()
+        else:
+            self._lock.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit"""
+        if self._is_async:
+            self._lock.release()
+        else:
+            self._lock.release()
+
+    def __enter__(self) -> 'UnifiedLock[T]':
+        """Sync context manager entry (for backward compatibility only)"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for asyncio.Lock")
+        self._lock.acquire()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Sync context manager exit (for backward compatibility only)"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for asyncio.Lock")
+        self._lock.release()
+
+
 LockType = Union[ProcessLock, asyncio.Lock]
 
 is_multiprocess = None

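A usage sketch for the UnifiedLock wrapper added above. It assumes the class is importable from lightrag.kg.shared_storage (the module these hunks modify); the demo itself is illustrative, not part of the commit:

import asyncio
import multiprocessing

from lightrag.kg.shared_storage import UnifiedLock

async def demo() -> None:
    # Async mode: wraps an asyncio.Lock, so __aenter__ awaits acquisition.
    async_lock = UnifiedLock(lock=asyncio.Lock(), is_async=True)
    async with async_lock:
        print("holding asyncio lock")

    # Sync mode: wraps a multiprocessing lock; `async with` still works,
    # but __aenter__ acquires it synchronously.
    proc_lock = UnifiedLock(lock=multiprocessing.Lock(), is_async=False)
    async with proc_lock:
        print("holding process lock")

asyncio.run(demo())

Note that the sync branch still blocks the event loop while waiting for the process lock: the wrapper unifies the interface, it does not make the process lock awaitable.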
@@ -117,24 +154,19 @@ async def get_update_flags(namespace: str):
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")
 
-    if is_multiprocess:
-        with _global_lock:
+    async with get_storage_lock():
         if namespace not in _update_flags:
-            if _manager is not None:
+            if is_multiprocess and _manager is not None:
                 _update_flags[namespace] = _manager.list()
-                direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
-
-            if _manager is not None:
-                new_update_flag = _manager.Value('b', False)
-                _update_flags[namespace].append(new_update_flag)
-                return new_update_flag
-    else:
-        async with _global_lock:
-            if namespace not in _update_flags:
+            else:
                 _update_flags[namespace] = []
             direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
 
+        if is_multiprocess and _manager is not None:
+            new_update_flag = _manager.Value('b', False)
+        else:
             new_update_flag = False
 
         _update_flags[namespace].append(new_update_flag)
         return new_update_flag

@@ -144,19 +176,14 @@ async def set_update_flag(namespace: str):
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")
 
-    if is_multiprocess:
-        with _global_lock:
+    async with get_storage_lock():
         if namespace not in _update_flags:
             raise ValueError(f"Namespace {namespace} not found in update flags")
-        # Update flags for multiprocess mode
+        # Update flags for both modes
         for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
                 _update_flags[namespace][i].value = True
             else:
-        async with _global_lock:
-            if namespace not in _update_flags:
-                raise ValueError(f"Namespace {namespace} not found in update flags")
-            # Update flags for single process mode
-            for i in range(len(_update_flags[namespace])):
                 _update_flags[namespace][i] = True

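The two flag representations above differ in how an update becomes visible. A standalone sketch (illustrative names, not from the commit): in multiprocess mode each flag is a Manager.Value whose .value mutates in shared memory, while single-process mode stores plain bools rebound in the list:

from multiprocessing import Manager

def demo_flags(is_multiprocess: bool) -> None:
    flags = []
    if is_multiprocess:
        manager = Manager()
        flags.append(manager.Value('b', False))  # byte-sized shared value
        flags[0].value = True                    # mutation visible across processes
        print(flags[0].value)
    else:
        flags.append(False)                      # plain bool, single process
        flags[0] = True                          # list slot rebound in place
        print(flags[0])

if __name__ == "__main__":
    demo_flags(False)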
@@ -182,9 +209,12 @@ def try_initialize_namespace(namespace: str) -> bool:
     return False
 
 
-def get_storage_lock() -> LockType:
-    """return storage lock for data consistency"""
-    return _global_lock
+def get_storage_lock() -> UnifiedLock:
+    """return unified storage lock for data consistency"""
+    return UnifiedLock(
+        lock=_global_lock,
+        is_async=not is_multiprocess
+    )
 
 
 async def get_namespace_data(namespace: str) -> Dict[str, Any]:

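With get_storage_lock() returning a UnifiedLock, call sites stop branching on is_multiprocess. A caller-side sketch (reset_busy is a hypothetical caller; it assumes shared storage has already been initialized):

from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock

async def reset_busy() -> None:
    # Single code path for both modes: UnifiedLock decides internally
    # whether acquisition is awaited (asyncio) or synchronous (process lock).
    pipeline_status = await get_namespace_data("pipeline_status")
    async with get_storage_lock():
        pipeline_status["busy"] = False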
@@ -196,14 +226,11 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
         )
         raise ValueError("Shared dictionaries not initialized")
 
-    if is_multiprocess:
-        with _global_lock:
+    async with get_storage_lock():
         if namespace not in _shared_dicts:
-            if _manager is not None:
+            if is_multiprocess and _manager is not None:
                 _shared_dicts[namespace] = _manager.dict()
             else:
-        async with _global_lock:
-            if namespace not in _shared_dicts:
                 _shared_dicts[namespace] = {}
 
     return _shared_dicts[namespace]

@@ -672,12 +672,12 @@ class LightRAG:
         from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock
 
         # Get pipeline status shared data and lock
-        pipeline_status = get_namespace_data("pipeline_status")
+        pipeline_status = await get_namespace_data("pipeline_status")
         storage_lock = get_storage_lock()
 
         # Check if another process is already processing the queue
         process_documents = False
-        with storage_lock:
+        async with storage_lock:
             # Ensure only one worker is processing documents
             if not pipeline_status.get("busy", False):
                 # Cleaning history_messages without breaking it as a shared list object

@@ -732,7 +732,6 @@ class LightRAG:
                 break
 
         # Update pipeline status with document count (with lock)
-        with storage_lock:
         pipeline_status["docs"] = len(to_process_docs)
 
         # 2. split docs into chunks, insert chunks, update doc status

@@ -852,7 +851,7 @@ class LightRAG:
 
                 # Check if there's a pending request to process more documents (with lock)
                 has_pending_request = False
-                with storage_lock:
+                async with storage_lock:
                     has_pending_request = pipeline_status.get("request_pending", False)
                 if has_pending_request:
                     # Clear the request flag before checking for more documents

@@ -867,11 +866,11 @@ class LightRAG:
                     pipeline_status["history_messages"].append(log_message)
 
         finally:
-            # Always reset busy status when done or if an exception occurs (with lock)
-            with storage_lock:
-                pipeline_status["busy"] = False
             log_message = "Document processing pipeline completed"
             logger.info(log_message)
+            # Always reset busy status when done or if an exception occurs (with lock)
+            async with storage_lock:
+                pipeline_status["busy"] = False
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)

@@ -911,7 +910,7 @@ class LightRAG:
             # Fetch pipeline_status and update latest_message and history_messages
             from lightrag.kg.shared_storage import get_namespace_data
 
-            pipeline_status = get_namespace_data("pipeline_status")
+            pipeline_status = await get_namespace_data("pipeline_status")
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)