Refactor shared storage module to improve async handling and naming consistency
• Add async support for get_namespace_data • Rename get_update_flags to get_update_flag • Rename set_update_flag to set_all_update_flags • Update docstrings for clarity • Fix typos in log messages
This commit is contained in:
@@ -667,7 +667,7 @@ def create_document_routes(
|
|||||||
try:
|
try:
|
||||||
from lightrag.kg.shared_storage import get_namespace_data
|
from lightrag.kg.shared_storage import get_namespace_data
|
||||||
|
|
||||||
pipeline_status = get_namespace_data("pipeline_status")
|
pipeline_status = await get_namespace_data("pipeline_status")
|
||||||
|
|
||||||
# Convert to regular dict if it's a Manager.dict
|
# Convert to regular dict if it's a Manager.dict
|
||||||
status_dict = dict(pipeline_status)
|
status_dict = dict(pipeline_status)
|
||||||
|
@@ -18,13 +18,12 @@ def direct_log(message, level="INFO"):
|
|||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
|
||||||
class UnifiedLock(Generic[T]):
|
class UnifiedLock(Generic[T]):
|
||||||
"""统一的锁包装类,提供同步和异步的统一接口"""
|
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
|
||||||
def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
|
def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
|
||||||
self._lock = lock
|
self._lock = lock
|
||||||
self._is_async = is_async
|
self._is_async = is_async
|
||||||
|
|
||||||
async def __aenter__(self) -> 'UnifiedLock[T]':
|
async def __aenter__(self) -> 'UnifiedLock[T]':
|
||||||
"""异步上下文管理器入口"""
|
|
||||||
if self._is_async:
|
if self._is_async:
|
||||||
await self._lock.acquire()
|
await self._lock.acquire()
|
||||||
else:
|
else:
|
||||||
@@ -32,21 +31,20 @@ class UnifiedLock(Generic[T]):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""异步上下文管理器出口"""
|
|
||||||
if self._is_async:
|
if self._is_async:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
else:
|
else:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def __enter__(self) -> 'UnifiedLock[T]':
|
def __enter__(self) -> 'UnifiedLock[T]':
|
||||||
"""同步上下文管理器入口(仅用于向后兼容)"""
|
"""For backward compatibility"""
|
||||||
if self._is_async:
|
if self._is_async:
|
||||||
raise RuntimeError("Use 'async with' for asyncio.Lock")
|
raise RuntimeError("Use 'async with' for asyncio.Lock")
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""同步上下文管理器出口(仅用于向后兼容)"""
|
"""For backward compatibility"""
|
||||||
if self._is_async:
|
if self._is_async:
|
||||||
raise RuntimeError("Use 'async with' for asyncio.Lock")
|
raise RuntimeError("Use 'async with' for asyncio.Lock")
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
@@ -153,10 +151,10 @@ async def initialize_pipeline_namespace():
|
|||||||
direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
|
direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
|
||||||
|
|
||||||
|
|
||||||
async def get_update_flags(namespace: str):
|
async def get_update_flag(namespace: str):
|
||||||
"""
|
"""
|
||||||
Create a updated flags of a specific namespace.
|
Create a namespace's update flag for a workers.
|
||||||
Caller must store the dict object locally and use it to determine when to update the storage.
|
Returen the update flag to caller for referencing or reset.
|
||||||
"""
|
"""
|
||||||
global _update_flags
|
global _update_flags
|
||||||
if _update_flags is None:
|
if _update_flags is None:
|
||||||
@@ -178,8 +176,8 @@ async def get_update_flags(namespace: str):
|
|||||||
_update_flags[namespace].append(new_update_flag)
|
_update_flags[namespace].append(new_update_flag)
|
||||||
return new_update_flag
|
return new_update_flag
|
||||||
|
|
||||||
async def set_update_flag(namespace: str):
|
async def set_all_update_flags(namespace: str):
|
||||||
"""Set all update flag of namespace to indicate storage needs updating"""
|
"""Set all update flag of namespace indicating all workers need to reload data from files"""
|
||||||
global _update_flags
|
global _update_flags
|
||||||
if _update_flags is None:
|
if _update_flags is None:
|
||||||
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
||||||
@@ -212,7 +210,7 @@ def try_initialize_namespace(namespace: str) -> bool:
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
direct_log(
|
direct_log(
|
||||||
f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]"
|
f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@@ -338,7 +338,7 @@ async def extract_entities(
|
|||||||
) -> None:
|
) -> None:
|
||||||
from lightrag.kg.shared_storage import get_namespace_data
|
from lightrag.kg.shared_storage import get_namespace_data
|
||||||
|
|
||||||
pipeline_status = get_namespace_data("pipeline_status")
|
pipeline_status = await get_namespace_data("pipeline_status")
|
||||||
use_llm_func: callable = global_config["llm_model_func"]
|
use_llm_func: callable = global_config["llm_model_func"]
|
||||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||||
enable_llm_cache_for_entity_extract: bool = global_config[
|
enable_llm_cache_for_entity_extract: bool = global_config[
|
||||||
|
Reference in New Issue
Block a user