Fix linting

yangdx
2025-03-01 16:23:34 +08:00
parent 3507e894d9
commit e3a40c2fdb
7 changed files with 138 additions and 95 deletions

View File

@@ -73,9 +73,12 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Acquire lock to prevent concurrent read and write
         with self._storage_lock:
             # Check if storage was updated by another process
-            if (is_multiprocess and self.storage_updated.value) or \
-               (not is_multiprocess and self.storage_updated):
-                logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process")
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
+                )
                 # Reload data
                 self._index = faiss.IndexFlatIP(self._dim)
                 self._id_to_meta = {}
@@ -86,7 +89,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 self.storage_updated = False
             return self._index
 
-
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.
@@ -337,32 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage):
             self._index = faiss.IndexFlatIP(self._dim)
             self._id_to_meta = {}
 
     async def index_done_callback(self) -> None:
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
-            logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...")
+            logger.warning(
+                f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
+            )
             with self._storage_lock:
                 self._index = faiss.IndexFlatIP(self._dim)
                 self._id_to_meta = {}
                 self._load_faiss_index()
                 self.storage_updated.value = False
             return False  # Return error
 
         # Acquire lock and perform persistence
         async with self._storage_lock:
             try:
                 # Save data to disk
                 self._save_faiss_index()
                 # Set all update flags to False
                 await set_all_update_flags(self.namespace)
                 # Reset own update flag to avoid self-reloading
                 if is_multiprocess:
                     self.storage_updated.value = False
                 else:
                     self.storage_updated = False
             except Exception as e:
                 logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
                 return False  # Return error
 
         return True  # Return success
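The FAISS hunk above is the reload-or-persist pattern this commit reformats in all of the file-backed storages: if another worker has already written newer data (signalled through the shared update flag), the callback reloads from disk instead of saving, so a stale in-memory copy never overwrites the update; otherwise it saves and flags every other worker to reload. A self-contained sketch of that control flow follows; load_from_disk/save_to_disk are hypothetical placeholders for the FAISS-specific calls, and the shared flag is simplified to a plain bool.

# Sketch of the reload-or-persist callback; placeholders, not LightRAG APIs.
import asyncio
import logging

logger = logging.getLogger(__name__)


class StorageSketch:
    def __init__(self, namespace: str):
        self.namespace = namespace
        self.storage_updated = False          # set by other workers after they save
        self._storage_lock = asyncio.Lock()

    def load_from_disk(self) -> None:         # stand-in for _load_faiss_index()
        pass

    def save_to_disk(self) -> None:           # stand-in for _save_faiss_index()
        pass

    async def index_done_callback(self) -> bool:
        if self.storage_updated:
            # Another worker wrote newer data: reload it instead of saving,
            # so this worker's stale copy does not overwrite the update.
            logger.warning("%s updated by another process, reloading", self.namespace)
            async with self._storage_lock:
                self.load_from_disk()
                self.storage_updated = False
            return False  # nothing was persisted

        # No concurrent update: persist, then tell every other worker to reload
        # (the real code does this via await set_all_update_flags(namespace)).
        async with self._storage_lock:
            try:
                self.save_to_disk()
                self.storage_updated = False  # do not reload our own write
            except Exception as e:
                logger.error("Error saving %s: %s", self.namespace, e)
                return False
        return True


asyncio.run(StorageSketch("chunks").index_done_callback())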

View File

@@ -64,9 +64,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Acquire lock to prevent concurrent read and write
         async with self._storage_lock:
             # Check if data needs to be reloaded
-            if (is_multiprocess and self.storage_updated.value) or \
-               (not is_multiprocess and self.storage_updated):
-                logger.info(f"Process {os.getpid()} reloading {self.namespace} due to update by another process")
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
+                )
                 # Reload data
                 self._client = NanoVectorDB(
                     self.embedding_func.embedding_dim,
@@ -204,7 +207,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
-            logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...")
+            logger.warning(
+                f"Storage for {self.namespace} was updated by another process, reloading..."
+            )
             self._client = NanoVectorDB(
                 self.embedding_func.embedding_dim,
                 storage_file=self._client_file_name,

View File

@@ -108,11 +108,16 @@ class NetworkXStorage(BaseGraphStorage):
         # Acquire lock to prevent concurrent read and write
         async with self._storage_lock:
             # Check if data needs to be reloaded
-            if (is_multiprocess and self.storage_updated.value) or \
-               (not is_multiprocess and self.storage_updated):
-                logger.info(f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process")
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
+                )
                 # Reload data
-                self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+                self._graph = (
+                    NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+                )
                 # Reset update flag
                 if is_multiprocess:
                     self.storage_updated.value = False
@@ -121,7 +126,6 @@ class NetworkXStorage(BaseGraphStorage):
         return self._graph
 
 
-
     async def has_node(self, node_id: str) -> bool:
         graph = await self._get_graph()
         return graph.has_node(node_id)
@@ -334,8 +338,12 @@ class NetworkXStorage(BaseGraphStorage):
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
-            logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...")
-            self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+            logger.warning(
+                f"Graph for {self.namespace} was updated by another process, reloading..."
+            )
+            self._graph = (
+                NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+            )
             # Reset update flag
             self.storage_updated.value = False
             return False  # Return error

View File

@@ -38,8 +38,8 @@ import pipmaster as pm
 if not pm.is_installed("asyncpg"):
     pm.install("asyncpg")
 
 import asyncpg  # type: ignore
 from asyncpg import Pool  # type: ignore
 
 
 class PostgreSQLDB:

View File

@@ -15,7 +15,7 @@ def direct_log(message, level="INFO"):
     print(f"{level}: {message}", file=sys.stderr, flush=True)
 
 
-T = TypeVar('T')
+T = TypeVar("T")
 LockType = Union[ProcessLock, asyncio.Lock]
 
 is_multiprocess = None
@@ -26,20 +26,22 @@ _initialized = None
 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
 _update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated
 
 # locks for mutex access
 _storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
 
+
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
+
     def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
         self._lock = lock
         self._is_async = is_async
 
-    async def __aenter__(self) -> 'UnifiedLock[T]':
+    async def __aenter__(self) -> "UnifiedLock[T]":
         if self._is_async:
             await self._lock.acquire()
         else:
@@ -52,7 +54,7 @@ class UnifiedLock(Generic[T]):
         else:
             self._lock.release()
 
-    def __enter__(self) -> 'UnifiedLock[T]':
+    def __enter__(self) -> "UnifiedLock[T]":
         """For backward compatibility"""
         if self._is_async:
             raise RuntimeError("Use 'async with' for shared_storage lock")
@@ -68,24 +70,18 @@ class UnifiedLock(Generic[T]):
 
 def get_internal_lock() -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(
-        lock=_internal_lock,
-        is_async=not is_multiprocess
-    )
+    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
 
+
 def get_storage_lock() -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(
-        lock=_storage_lock,
-        is_async=not is_multiprocess
-    )
+    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
 
+
 def get_pipeline_status_lock() -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(
-        lock=_pipeline_status_lock,
-        is_async=not is_multiprocess
-    )
+    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
 
+
 def initialize_share_data(workers: int = 1):
     """
@@ -166,17 +162,19 @@ async def initialize_pipeline_status():
         # Create a shared list object for history_messages
         history_messages = _manager.list() if is_multiprocess else []
-        pipeline_namespace.update({
-            "busy": False,  # Control concurrent processes
-            "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
-            "job_start": None,  # Job start time
-            "docs": 0,  # Total number of documents to be indexed
-            "batchs": 0,  # Number of batches for processing documents
-            "cur_batch": 0,  # Current processing batch
-            "request_pending": False,  # Flag for pending request for processing
-            "latest_message": "",  # Latest message from pipeline processing
-            "history_messages": history_messages,  # Use a shared list object
-        })
+        pipeline_namespace.update(
+            {
+                "busy": False,  # Control concurrent processes
+                "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
+                "job_start": None,  # Job start time
+                "docs": 0,  # Total number of documents to be indexed
+                "batchs": 0,  # Number of batches for processing documents
+                "cur_batch": 0,  # Current processing batch
+                "request_pending": False,  # Flag for pending request for processing
+                "latest_message": "",  # Latest message from pipeline processing
+                "history_messages": history_messages,  # Use a shared list object
+            }
+        )
 
         direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
@@ -195,16 +193,19 @@ async def get_update_flag(namespace: str):
                 _update_flags[namespace] = _manager.list()
             else:
                 _update_flags[namespace] = []
-            direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
+            direct_log(
+                f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
+            )
 
         if is_multiprocess and _manager is not None:
-            new_update_flag = _manager.Value('b', False)
+            new_update_flag = _manager.Value("b", False)
         else:
             new_update_flag = False
 
         _update_flags[namespace].append(new_update_flag)
         return new_update_flag
 
+
 async def set_all_update_flags(namespace: str):
     """Set all update flag of namespace indicating all workers need to reload data from files"""
     global _update_flags
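For reference, the lock getters shown above are designed to be used with "async with": UnifiedLock awaits acquire() on an asyncio.Lock when the server runs single-process, and falls back to the blocking multiprocessing acquire() when gunicorn starts several workers. A minimal usage sketch, assuming initialize_share_data() has been called in the current process so the module-level locks exist; persist_index here is a hypothetical callable, not a LightRAG API.

# Usage sketch only; persist_index is a placeholder callable.
import asyncio

from lightrag.kg.shared_storage import get_storage_lock, initialize_share_data


async def guarded_save(persist_index) -> None:
    # Same pattern as the storage classes in this commit: one writer at a time,
    # whether the underlying lock is asyncio- or multiprocessing-based.
    async with get_storage_lock():
        persist_index()


initialize_share_data(workers=1)  # workers=1 keeps the asyncio locks
asyncio.run(guarded_save(lambda: None))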

View File

@@ -696,7 +696,10 @@ class LightRAG:
         3. Process each chunk for entity and relation extraction
         4. Update the document status
         """
-        from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock
+        from lightrag.kg.shared_storage import (
+            get_namespace_data,
+            get_pipeline_status_lock,
+        )
 
         # Get pipeline status shared data and lock
         pipeline_status = await get_namespace_data("pipeline_status")

View File

@@ -47,6 +47,7 @@ def main():
 
     # Check and install gunicorn if not present
     import pipmaster as pm
+
     if not pm.is_installed("gunicorn"):
         print("Installing gunicorn...")
         pm.install("gunicorn")
@@ -103,7 +104,9 @@ def main():
     import gunicorn_config
 
     # Set configuration variables in gunicorn_config, prioritizing command line arguments
-    gunicorn_config.workers = args.workers if args.workers else int(os.getenv("WORKERS", 1))
+    gunicorn_config.workers = (
+        args.workers if args.workers else int(os.getenv("WORKERS", 1))
+    )
 
     # Bind configuration prioritizes command line arguments
     host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
@@ -111,18 +114,36 @@ def main():
     gunicorn_config.bind = f"{host}:{port}"
 
     # Log level configuration prioritizes command line arguments
-    gunicorn_config.loglevel = args.log_level.lower() if args.log_level else os.getenv("LOG_LEVEL", "info")
+    gunicorn_config.loglevel = (
+        args.log_level.lower()
+        if args.log_level
+        else os.getenv("LOG_LEVEL", "info")
+    )
 
     # Timeout configuration prioritizes command line arguments
-    gunicorn_config.timeout = args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
+    gunicorn_config.timeout = (
+        args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
+    )
 
     # Keepalive configuration
     gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
 
     # SSL configuration prioritizes command line arguments
-    if args.ssl or os.getenv("SSL", "").lower() in ("true", "1", "yes", "t", "on"):
-        gunicorn_config.certfile = args.ssl_certfile if args.ssl_certfile else os.getenv("SSL_CERTFILE")
-        gunicorn_config.keyfile = args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
+    if args.ssl or os.getenv("SSL", "").lower() in (
+        "true",
+        "1",
+        "yes",
+        "t",
+        "on",
+    ):
+        gunicorn_config.certfile = (
+            args.ssl_certfile
+            if args.ssl_certfile
+            else os.getenv("SSL_CERTFILE")
+        )
+        gunicorn_config.keyfile = (
+            args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
+        )
 
     # Set configuration options from the module
     for key in dir(gunicorn_config):
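The two gunicorn hunks repeat a single precedence rule for every setting: the command-line argument wins, then the environment variable, then a built-in default. A hypothetical helper that states the rule once, shown only to make the precedence explicit; the commit itself keeps the inline conditional expressions.

# Illustration only; from_cli_or_env is not part of the codebase.
import os
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def from_cli_or_env(
    cli_value: Optional[T],
    env_name: str,
    default: str,
    cast: Callable[[str], T] = str,
) -> T:
    # Command-line argument first, then environment variable, then default.
    if cli_value:
        return cli_value
    return cast(os.getenv(env_name, default))


# Roughly equivalent to the assignments above:
# gunicorn_config.workers = from_cli_or_env(args.workers, "WORKERS", "1", int)
# gunicorn_config.timeout = from_cli_or_env(args.timeout, "TIMEOUT", "150", int)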