Fix linting

yangdx
2025-03-01 16:23:34 +08:00
parent 3507e894d9
commit e3a40c2fdb
7 changed files with 138 additions and 95 deletions


@@ -73,9 +73,12 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Acquire lock to prevent concurrent read and write
         with self._storage_lock:
             # Check if storage was updated by another process
-            if (is_multiprocess and self.storage_updated.value) or \
-                    (not is_multiprocess and self.storage_updated):
-                logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process")
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
+                )
                 # Reload data
                 self._index = faiss.IndexFlatIP(self._dim)
                 self._id_to_meta = {}
@@ -86,7 +89,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 self.storage_updated = False
             return self._index
 
-
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.
@@ -337,11 +339,14 @@ class FaissVectorDBStorage(BaseVectorStorage):
             self._index = faiss.IndexFlatIP(self._dim)
             self._id_to_meta = {}
 
     async def index_done_callback(self) -> None:
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
-            logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...")
+            logger.warning(
+                f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
+            )
             with self._storage_lock:
                 self._index = faiss.IndexFlatIP(self._dim)
                 self._id_to_meta = {}
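
All of the storage backends touched by this commit share the same check-and-reload pattern: a per-namespace update flag (a Manager.Value in multi-process mode, a plain bool otherwise) is inspected under the storage lock, and the in-memory data is rebuilt from disk when another worker has changed it. Below is a minimal, self-contained sketch of that pattern; the class name and the _load_from_disk helper are illustrative stand-ins, not part of the repository.

import threading


class ReloadOnUpdateSketch:
    """Sketch of the check-and-reload pattern used by the storage backends above."""

    def __init__(self, multiprocess: bool, storage_updated, lock=None):
        self.multiprocess = multiprocess        # mirrors the global is_multiprocess
        self.storage_updated = storage_updated  # Manager.Value("b", ...) or plain bool
        self._storage_lock = lock or threading.Lock()
        self._data = None

    def _needs_reload(self) -> bool:
        if self.multiprocess:
            return bool(self.storage_updated.value)
        return bool(self.storage_updated)

    def get_data(self):
        with self._storage_lock:  # prevent concurrent read and write
            if self._needs_reload():
                self._data = self._load_from_disk()  # hypothetical backend-specific reload
                if self.multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False
            return self._data

    def _load_from_disk(self):
        return {}  # placeholder for a FAISS index, NanoVectorDB client, or nx.Graph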


@@ -64,9 +64,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Acquire lock to prevent concurrent read and write
         async with self._storage_lock:
             # Check if data needs to be reloaded
-            if (is_multiprocess and self.storage_updated.value) or \
-                    (not is_multiprocess and self.storage_updated):
-                logger.info(f"Process {os.getpid()} reloading {self.namespace} due to update by another process")
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
+                )
                 # Reload data
                 self._client = NanoVectorDB(
                     self.embedding_func.embedding_dim,
@@ -204,7 +207,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
-            logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...")
+            logger.warning(
+                f"Storage for {self.namespace} was updated by another process, reloading..."
+            )
             self._client = NanoVectorDB(
                 self.embedding_func.embedding_dim,
                 storage_file=self._client_file_name,


@@ -108,11 +108,16 @@ class NetworkXStorage(BaseGraphStorage):
         # Acquire lock to prevent concurrent read and write
         async with self._storage_lock:
             # Check if data needs to be reloaded
-            if (is_multiprocess and self.storage_updated.value) or \
-                    (not is_multiprocess and self.storage_updated):
-                logger.info(f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process")
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
+                )
                 # Reload data
-                self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+                self._graph = (
+                    NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+                )
                 # Reset update flag
                 if is_multiprocess:
                     self.storage_updated.value = False
@@ -121,7 +126,6 @@ class NetworkXStorage(BaseGraphStorage):
             return self._graph
 
-
     async def has_node(self, node_id: str) -> bool:
         graph = await self._get_graph()
         return graph.has_node(node_id)
@@ -334,8 +338,12 @@ class NetworkXStorage(BaseGraphStorage):
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
-            logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...")
-            self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+            logger.warning(
+                f"Graph for {self.namespace} was updated by another process, reloading..."
+            )
+            self._graph = (
+                NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+            )
             # Reset update flag
             self.storage_updated.value = False
             return False  # Return error
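
The index_done_callback hunks apply the other half of the rule: when the flag shows another worker already rewrote the files, the local copy is discarded and reloaded instead of being saved, and the callback reports that nothing was persisted. A rough sketch of that decision follows; finish_indexing and the storage.reload()/storage.save() helpers are hypothetical stand-ins for the backend-specific I/O.

import logging

logger = logging.getLogger(__name__)


async def finish_indexing(storage, is_multiprocess: bool) -> bool:
    """Sketch: save our data unless another worker already rewrote the files."""
    async with storage.lock:  # assumed to be an async-capable lock, e.g. UnifiedLock
        if is_multiprocess and storage.storage_updated.value:
            logger.warning("storage was updated by another process, reloading...")
            storage.reload()                    # hypothetical: rebuild from files on disk
            storage.storage_updated.value = False
            return False                        # nothing was saved
        storage.save()                          # hypothetical: persist the in-memory data
        return True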


@@ -15,7 +15,7 @@ def direct_log(message, level="INFO"):
     print(f"{level}: {message}", file=sys.stderr, flush=True)
 
 
-T = TypeVar('T')
+T = TypeVar("T")
 LockType = Union[ProcessLock, asyncio.Lock]
 
 is_multiprocess = None
@@ -33,13 +33,15 @@ _storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
 
+
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
+
     def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
         self._lock = lock
         self._is_async = is_async
 
-    async def __aenter__(self) -> 'UnifiedLock[T]':
+    async def __aenter__(self) -> "UnifiedLock[T]":
         if self._is_async:
             await self._lock.acquire()
         else:
@@ -52,7 +54,7 @@ class UnifiedLock(Generic[T]):
         else:
             self._lock.release()
 
-    def __enter__(self) -> 'UnifiedLock[T]':
+    def __enter__(self) -> "UnifiedLock[T]":
         """For backward compatibility"""
         if self._is_async:
             raise RuntimeError("Use 'async with' for shared_storage lock")
@@ -68,24 +70,18 @@ class UnifiedLock(Generic[T]):
 
 def get_internal_lock() -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(
-        lock=_internal_lock,
-        is_async=not is_multiprocess
-    )
+    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
 
 
 def get_storage_lock() -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(
-        lock=_storage_lock,
-        is_async=not is_multiprocess
-    )
+    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
 
 
 def get_pipeline_status_lock() -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(
-        lock=_pipeline_status_lock,
-        is_async=not is_multiprocess
-    )
+    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
 
 
 def initialize_share_data(workers: int = 1):
     """
@@ -166,7 +162,8 @@ async def initialize_pipeline_status():
         # Create a shared list object for history_messages
         history_messages = _manager.list() if is_multiprocess else []
-        pipeline_namespace.update({
-            "busy": False,  # Control concurrent processes
-            "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
-            "job_start": None,  # Job start time
+        pipeline_namespace.update(
+            {
+                "busy": False,  # Control concurrent processes
+                "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
+                "job_start": None,  # Job start time
@@ -176,7 +173,8 @@ async def initialize_pipeline_status():
-            "request_pending": False,  # Flag for pending request for processing
-            "latest_message": "",  # Latest message from pipeline processing
-            "history_messages": history_messages,  # Use the shared list object
-        })
+                "request_pending": False,  # Flag for pending request for processing
+                "latest_message": "",  # Latest message from pipeline processing
+                "history_messages": history_messages,  # Use the shared list object
+            }
+        )
 
         direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
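
The pipeline_status namespace initialized here is a plain dict in single-process mode and a Manager dict across workers, with history_messages as a shared list so every worker appends to the same log. A short sketch of reading it through the public helpers, assuming shared storage and the pipeline status have already been initialized:

from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_pipeline_status_lock,
)


async def report_pipeline() -> str:
    # The dict and its lock behave the same whether they are backed by plain
    # Python objects or by a multiprocessing.Manager.
    pipeline_status = await get_namespace_data("pipeline_status")
    async with get_pipeline_status_lock():
        busy = pipeline_status["busy"]
        latest = pipeline_status["latest_message"]
    return f"busy={busy} latest={latest!r}"
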
@@ -195,16 +193,19 @@ async def get_update_flag(namespace: str):
             _update_flags[namespace] = _manager.list()
         else:
             _update_flags[namespace] = []
-        direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
+        direct_log(
+            f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
+        )
 
     if is_multiprocess and _manager is not None:
-        new_update_flag = _manager.Value('b', False)
+        new_update_flag = _manager.Value("b", False)
     else:
         new_update_flag = False
 
     _update_flags[namespace].append(new_update_flag)
     return new_update_flag
 
 
 async def set_all_update_flags(namespace: str):
     """Set all update flag of namespace indicating all workers need to reload data from files"""
     global _update_flags
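
get_update_flag registers one flag per worker and namespace, while set_all_update_flags flips every registered flag so the other workers know to reload on their next access. A minimal sketch of that handshake; the namespace name is made up for illustration.

from lightrag.kg.shared_storage import get_update_flag, set_all_update_flags


async def publish_update(namespace: str = "demo_namespace"):  # hypothetical namespace
    # Each worker keeps its own flag object; in multi-process mode it is a
    # Manager.Value("b", False) that other processes can flip.
    my_flag = await get_update_flag(namespace)

    # After this worker persists new data for the namespace, notify the others
    # so they reload from disk on their next access.
    await set_all_update_flags(namespace)
    return my_flag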


@@ -696,7 +696,10 @@ class LightRAG:
         3. Process each chunk for entity and relation extraction
         4. Update the document status
         """
-        from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock
+        from lightrag.kg.shared_storage import (
+            get_namespace_data,
+            get_pipeline_status_lock,
+        )
 
         # Get pipeline status shared data and lock
         pipeline_status = await get_namespace_data("pipeline_status")


@@ -47,6 +47,7 @@ def main():
     # Check and install gunicorn if not present
     import pipmaster as pm
+
     if not pm.is_installed("gunicorn"):
         print("Installing gunicorn...")
         pm.install("gunicorn")
@@ -103,7 +104,9 @@ def main():
     import gunicorn_config
 
     # Set configuration variables in gunicorn_config, prioritizing command line arguments
-    gunicorn_config.workers = args.workers if args.workers else int(os.getenv("WORKERS", 1))
+    gunicorn_config.workers = (
+        args.workers if args.workers else int(os.getenv("WORKERS", 1))
+    )
 
     # Bind configuration prioritizes command line arguments
     host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
@@ -111,18 +114,36 @@ def main():
     gunicorn_config.bind = f"{host}:{port}"
 
     # Log level configuration prioritizes command line arguments
-    gunicorn_config.loglevel = args.log_level.lower() if args.log_level else os.getenv("LOG_LEVEL", "info")
+    gunicorn_config.loglevel = (
+        args.log_level.lower()
+        if args.log_level
+        else os.getenv("LOG_LEVEL", "info")
+    )
 
     # Timeout configuration prioritizes command line arguments
-    gunicorn_config.timeout = args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
+    gunicorn_config.timeout = (
+        args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
+    )
 
     # Keepalive configuration
     gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
 
     # SSL configuration prioritizes command line arguments
-    if args.ssl or os.getenv("SSL", "").lower() in ("true", "1", "yes", "t", "on"):
-        gunicorn_config.certfile = args.ssl_certfile if args.ssl_certfile else os.getenv("SSL_CERTFILE")
-        gunicorn_config.keyfile = args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
+    if args.ssl or os.getenv("SSL", "").lower() in (
+        "true",
+        "1",
+        "yes",
+        "t",
+        "on",
+    ):
+        gunicorn_config.certfile = (
+            args.ssl_certfile
+            if args.ssl_certfile
+            else os.getenv("SSL_CERTFILE")
+        )
+        gunicorn_config.keyfile = (
+            args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
+        )
 
     # Set configuration options from the module
     for key in dir(gunicorn_config):
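
Every gunicorn setting above follows the same precedence: an explicit command-line argument wins, otherwise the environment variable, otherwise a hard-coded default. A small sketch of that rule with a hypothetical pick() helper that is not part of the script:

import os


def pick(cli_value, env_name: str, default, cast=lambda v: v):
    """CLI argument first, then environment variable, then the default."""
    if cli_value:
        return cli_value
    return cast(os.getenv(env_name, default))


# Mirrors the gunicorn_config.workers / timeout assignments in main():
workers = pick(None, "WORKERS", 1, int)   # no CLI value -> env var or default
timeout = pick(300, "TIMEOUT", 150, int)  # explicit CLI value wins
print(workers, timeout)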