From e3a40c2fdbc041e176c642ece9e576385d2b0502 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 16:23:34 +0800 Subject: [PATCH] Fix linting --- lightrag/kg/faiss_impl.py | 65 ++++++++++++++------------- lightrag/kg/nano_vector_db_impl.py | 17 ++++--- lightrag/kg/networkx_impl.py | 30 ++++++++----- lightrag/kg/postgres_impl.py | 4 +- lightrag/kg/shared_storage.py | 71 +++++++++++++++--------------- lightrag/lightrag.py | 5 ++- run_with_gunicorn.py | 41 ++++++++++++----- 7 files changed, 138 insertions(+), 95 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index f244c288..bb4d47ec 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -50,7 +50,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim - + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). # If you have a large number of vectors, you might want IVF or other indexes. # For demonstration, we use a simple IndexFlatIP. @@ -73,9 +73,12 @@ class FaissVectorDBStorage(BaseVectorStorage): # Acquire lock to prevent concurrent read and write with self._storage_lock: # Check if storage was updated by another process - if (is_multiprocess and self.storage_updated.value) or \ - (not is_multiprocess and self.storage_updated): - logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process") + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process" + ) # Reload data self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} @@ -86,7 +89,6 @@ class FaissVectorDBStorage(BaseVectorStorage): self.storage_updated = False return self._index - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -337,32 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage): self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} + async def index_done_callback(self) -> None: - # Check if storage was updated by another process - if is_multiprocess and self.storage_updated.value: - # Storage was updated by another process, reload data instead of saving - logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...") - with self._storage_lock: - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta = {} - self._load_faiss_index() + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for FAISS {self.namespace} was updated by another process, reloading..." + ) + with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._save_faiss_index() + # Set all update flags to False + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: self.storage_updated.value = False + else: + self.storage_updated = False + except Exception as e: + logger.error(f"Error saving FAISS index for {self.namespace}: {e}") return False # Return error - # Acquire lock and perform persistence - async with self._storage_lock: - try: - # Save data to disk - self._save_faiss_index() - # Set all update flags to False - await set_all_update_flags(self.namespace) - # Reset own update flag to avoid self-reloading - if is_multiprocess: - self.storage_updated.value = False - else: - self.storage_updated = False - except Exception as e: - logger.error(f"Error saving FAISS index for {self.namespace}: {e}") - return False # Return error - - return True # Return success + return True # Return success diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index e0ecacdf..07c800de 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -64,9 +64,12 @@ class NanoVectorDBStorage(BaseVectorStorage): # Acquire lock to prevent concurrent read and write async with self._storage_lock: # Check if data needs to be reloaded - if (is_multiprocess and self.storage_updated.value) or \ - (not is_multiprocess and self.storage_updated): - logger.info(f"Process {os.getpid()} reloading {self.namespace} due to update by another process") + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} reloading {self.namespace} due to update by another process" + ) # Reload data self._client = NanoVectorDB( self.embedding_func.embedding_dim, @@ -77,7 +80,7 @@ class NanoVectorDBStorage(BaseVectorStorage): self.storage_updated.value = False else: self.storage_updated = False - + return self._client async def upsert(self, data: dict[str, dict[str, Any]]) -> None: @@ -204,7 +207,9 @@ class NanoVectorDBStorage(BaseVectorStorage): # Check if storage was updated by another process if is_multiprocess and self.storage_updated.value: # Storage was updated by another process, reload data instead of saving - logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...") + logger.warning( + f"Storage for {self.namespace} was updated by another process, reloading..." + ) self._client = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name, @@ -212,7 +217,7 @@ class NanoVectorDBStorage(BaseVectorStorage): # Reset update flag self.storage_updated.value = False return False # Return error - + # Acquire lock and perform persistence async with self._storage_lock: try: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 2e61e6b3..f11e9c0e 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -91,7 +91,7 @@ class NetworkXStorage(BaseGraphStorage): else: logger.info("Created new empty graph") self._graph = preloaded_graph or nx.Graph() - + self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } @@ -108,19 +108,23 @@ class NetworkXStorage(BaseGraphStorage): # Acquire lock to prevent concurrent read and write async with self._storage_lock: # Check if data needs to be reloaded - if (is_multiprocess and self.storage_updated.value) or \ - (not is_multiprocess and self.storage_updated): - logger.info(f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process") + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process" + ) # Reload data - self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + self._graph = ( + NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + ) # Reset update flag if is_multiprocess: self.storage_updated.value = False else: self.storage_updated = False - - return self._graph + return self._graph async def has_node(self, node_id: str) -> bool: graph = await self._get_graph() @@ -334,12 +338,16 @@ class NetworkXStorage(BaseGraphStorage): # Check if storage was updated by another process if is_multiprocess and self.storage_updated.value: # Storage was updated by another process, reload data instead of saving - logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...") - self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + logger.warning( + f"Graph for {self.namespace} was updated by another process, reloading..." + ) + self._graph = ( + NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + ) # Reset update flag self.storage_updated.value = False return False # Return error - + # Acquire lock and perform persistence async with self._storage_lock: try: @@ -356,5 +364,5 @@ class NetworkXStorage(BaseGraphStorage): except Exception as e: logger.error(f"Error saving graph for {self.namespace}: {e}") return False # Return error - + return True diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 10883a88..51044be5 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -38,8 +38,8 @@ import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") -import asyncpg # type: ignore -from asyncpg import Pool # type: ignore +import asyncpg # type: ignore +from asyncpg import Pool # type: ignore class PostgreSQLDB: diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 27d23f2e..acebafa7 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -15,7 +15,7 @@ def direct_log(message, level="INFO"): print(f"{level}: {message}", file=sys.stderr, flush=True) -T = TypeVar('T') +T = TypeVar("T") LockType = Union[ProcessLock, asyncio.Lock] is_multiprocess = None @@ -26,20 +26,22 @@ _initialized = None # shared data for storage across processes _shared_dicts: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized -_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated +_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated # locks for mutex access _storage_lock: Optional[LockType] = None _internal_lock: Optional[LockType] = None _pipeline_status_lock: Optional[LockType] = None + class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" + def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool): self._lock = lock self._is_async = is_async - async def __aenter__(self) -> 'UnifiedLock[T]': + async def __aenter__(self) -> "UnifiedLock[T]": if self._is_async: await self._lock.acquire() else: @@ -52,7 +54,7 @@ class UnifiedLock(Generic[T]): else: self._lock.release() - def __enter__(self) -> 'UnifiedLock[T]': + def __enter__(self) -> "UnifiedLock[T]": """For backward compatibility""" if self._is_async: raise RuntimeError("Use 'async with' for shared_storage lock") @@ -68,24 +70,18 @@ class UnifiedLock(Generic[T]): def get_internal_lock() -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock( - lock=_internal_lock, - is_async=not is_multiprocess - ) + return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess) + def get_storage_lock() -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock( - lock=_storage_lock, - is_async=not is_multiprocess - ) + return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess) + def get_pipeline_status_lock() -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock( - lock=_pipeline_status_lock, - is_async=not is_multiprocess - ) + return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess) + def initialize_share_data(workers: int = 1): """ @@ -166,17 +162,19 @@ async def initialize_pipeline_status(): # Create a shared list object for history_messages history_messages = _manager.list() if is_multiprocess else [] - pipeline_namespace.update({ - "busy": False, # Control concurrent processes - "job_name": "Default Job", # Current job name (indexing files/indexing texts) - "job_start": None, # Job start time - "docs": 0, # Total number of documents to be indexed - "batchs": 0, # Number of batches for processing documents - "cur_batch": 0, # Current processing batch - "request_pending": False, # Flag for pending request for processing - "latest_message": "", # Latest message from pipeline processing - "history_messages": history_messages, # 使用共享列表对象 - }) + pipeline_namespace.update( + { + "busy": False, # Control concurrent processes + "job_name": "Default Job", # Current job name (indexing files/indexing texts) + "job_start": None, # Job start time + "docs": 0, # Total number of documents to be indexed + "batchs": 0, # Number of batches for processing documents + "cur_batch": 0, # Current processing batch + "request_pending": False, # Flag for pending request for processing + "latest_message": "", # Latest message from pipeline processing + "history_messages": history_messages, # 使用共享列表对象 + } + ) direct_log(f"Process {os.getpid()} Pipeline namespace initialized") @@ -195,22 +193,25 @@ async def get_update_flag(namespace: str): _update_flags[namespace] = _manager.list() else: _update_flags[namespace] = [] - direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") - + direct_log( + f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]" + ) + if is_multiprocess and _manager is not None: - new_update_flag = _manager.Value('b', False) + new_update_flag = _manager.Value("b", False) else: new_update_flag = False - + _update_flags[namespace].append(new_update_flag) return new_update_flag + async def set_all_update_flags(namespace: str): """Set all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - + async with get_internal_lock(): if namespace not in _update_flags: raise ValueError(f"Namespace {namespace} not found in update flags") @@ -225,13 +226,13 @@ async def set_all_update_flags(namespace: str): async def get_all_update_flags_status() -> Dict[str, list]: """ Get update flags status for all namespaces. - + Returns: Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses """ if _update_flags is None: return {} - + result = {} async with get_internal_lock(): for namespace, flags in _update_flags.items(): @@ -242,7 +243,7 @@ async def get_all_update_flags_status() -> Dict[str, list]: else: worker_statuses.append(flag) result[namespace] = worker_statuses - + return result diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6008b39c..44b77ae7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -696,7 +696,10 @@ class LightRAG: 3. Process each chunk for entity and relation extraction 4. Update the document status """ - from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_pipeline_status_lock, + ) # Get pipeline status shared data and lock pipeline_status = await get_namespace_data("pipeline_status") diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index decd91de..644e6e87 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -47,10 +47,11 @@ def main(): # Check and install gunicorn if not present import pipmaster as pm + if not pm.is_installed("gunicorn"): print("Installing gunicorn...") pm.install("gunicorn") - + # Import Gunicorn's StandaloneApplication from gunicorn.app.base import BaseApplication @@ -103,26 +104,46 @@ def main(): import gunicorn_config # Set configuration variables in gunicorn_config, prioritizing command line arguments - gunicorn_config.workers = args.workers if args.workers else int(os.getenv("WORKERS", 1)) - + gunicorn_config.workers = ( + args.workers if args.workers else int(os.getenv("WORKERS", 1)) + ) + # Bind configuration prioritizes command line arguments host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) gunicorn_config.bind = f"{host}:{port}" - + # Log level configuration prioritizes command line arguments - gunicorn_config.loglevel = args.log_level.lower() if args.log_level else os.getenv("LOG_LEVEL", "info") + gunicorn_config.loglevel = ( + args.log_level.lower() + if args.log_level + else os.getenv("LOG_LEVEL", "info") + ) # Timeout configuration prioritizes command line arguments - gunicorn_config.timeout = args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) - + gunicorn_config.timeout = ( + args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) + ) + # Keepalive configuration gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ("true", "1", "yes", "t", "on"): - gunicorn_config.certfile = args.ssl_certfile if args.ssl_certfile else os.getenv("SSL_CERTFILE") - gunicorn_config.keyfile = args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + if args.ssl or os.getenv("SSL", "").lower() in ( + "true", + "1", + "yes", + "t", + "on", + ): + gunicorn_config.certfile = ( + args.ssl_certfile + if args.ssl_certfile + else os.getenv("SSL_CERTFILE") + ) + gunicorn_config.keyfile = ( + args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + ) # Set configuration options from the module for key in dir(gunicorn_config):