Fix linting

yangdx
2025-03-01 16:23:34 +08:00
parent 3507e894d9
commit e3a40c2fdb
7 changed files with 138 additions and 95 deletions

View File

@@ -50,7 +50,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
        self._max_batch_size = self.global_config["embedding_batch_num"]
        # Embedding dimension (e.g. 768) must match your embedding function
        self._dim = self.embedding_func.embedding_dim

        # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
        # If you have a large number of vectors, you might want IVF or other indexes.
        # For demonstration, we use a simple IndexFlatIP.
@@ -73,9 +73,12 @@ class FaissVectorDBStorage(BaseVectorStorage):
        # Acquire lock to prevent concurrent read and write
        with self._storage_lock:
            # Check if storage was updated by another process
            if (is_multiprocess and self.storage_updated.value) or (
                not is_multiprocess and self.storage_updated
            ):
                logger.info(
                    f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
                )
                # Reload data
                self._index = faiss.IndexFlatIP(self._dim)
                self._id_to_meta = {}
@@ -86,7 +89,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
                    self.storage_updated = False
            return self._index

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """
        Insert or update vectors in the Faiss index.
@@ -337,32 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage):
        self._index = faiss.IndexFlatIP(self._dim)
        self._id_to_meta = {}

    async def index_done_callback(self) -> None:
        # Check if storage was updated by another process
        if is_multiprocess and self.storage_updated.value:
            # Storage was updated by another process, reload data instead of saving
            logger.warning(
                f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
            )
            with self._storage_lock:
                self._index = faiss.IndexFlatIP(self._dim)
                self._id_to_meta = {}
                self._load_faiss_index()
                self.storage_updated.value = False
            return False  # Return error

        # Acquire lock and perform persistence
        async with self._storage_lock:
            try:
                # Save data to disk
                self._save_faiss_index()
                # Set all update flags to False
                await set_all_update_flags(self.namespace)
                # Reset own update flag to avoid self-reloading
                if is_multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False
            except Exception as e:
                logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
                return False  # Return error

        return True  # Return success
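Aside (not part of the commit): the comment in the first hunk notes that an inner-product index over L2-normalized vectors behaves like cosine similarity. A minimal, self-contained sketch of that idea, with made-up dimensions and random data:

    # Illustration only: inner product on L2-normalized vectors == cosine similarity.
    # Assumes faiss-cpu and numpy are installed; dim and the data are arbitrary.
    import faiss
    import numpy as np

    dim = 128
    vectors = np.random.rand(1000, dim).astype("float32")
    faiss.normalize_L2(vectors)        # normalize in place so IP scores are cosines

    index = faiss.IndexFlatIP(dim)     # exact inner-product index, as in the storage class
    index.add(vectors)

    query = np.random.rand(1, dim).astype("float32")
    faiss.normalize_L2(query)
    scores, ids = index.search(query, 5)   # top-5 cosine similarities and their row ids
    print(ids[0], scores[0])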

View File

@@ -64,9 +64,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
        # Acquire lock to prevent concurrent read and write
        async with self._storage_lock:
            # Check if data needs to be reloaded
            if (is_multiprocess and self.storage_updated.value) or (
                not is_multiprocess and self.storage_updated
            ):
                logger.info(
                    f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
                )
                # Reload data
                self._client = NanoVectorDB(
                    self.embedding_func.embedding_dim,
@@ -77,7 +80,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False

            return self._client

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -204,7 +207,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
        # Check if storage was updated by another process
        if is_multiprocess and self.storage_updated.value:
            # Storage was updated by another process, reload data instead of saving
            logger.warning(
                f"Storage for {self.namespace} was updated by another process, reloading..."
            )
            self._client = NanoVectorDB(
                self.embedding_func.embedding_dim,
                storage_file=self._client_file_name,
@@ -212,7 +217,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
            # Reset update flag
            self.storage_updated.value = False
            return False  # Return error

        # Acquire lock and perform persistence
        async with self._storage_lock:
            try:

View File

@@ -91,7 +91,7 @@ class NetworkXStorage(BaseGraphStorage):
        else:
            logger.info("Created new empty graph")
            self._graph = preloaded_graph or nx.Graph()

        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }
@@ -108,19 +108,23 @@ class NetworkXStorage(BaseGraphStorage):
        # Acquire lock to prevent concurrent read and write
        async with self._storage_lock:
            # Check if data needs to be reloaded
            if (is_multiprocess and self.storage_updated.value) or (
                not is_multiprocess and self.storage_updated
            ):
                logger.info(
                    f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
                )
                # Reload data
                self._graph = (
                    NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
                )
                # Reset update flag
                if is_multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False

            return self._graph

    async def has_node(self, node_id: str) -> bool:
        graph = await self._get_graph()
@@ -334,12 +338,16 @@ class NetworkXStorage(BaseGraphStorage):
        # Check if storage was updated by another process
        if is_multiprocess and self.storage_updated.value:
            # Storage was updated by another process, reload data instead of saving
            logger.warning(
                f"Graph for {self.namespace} was updated by another process, reloading..."
            )
            self._graph = (
                NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
            )
            # Reset update flag
            self.storage_updated.value = False
            return False  # Return error

        # Acquire lock and perform persistence
        async with self._storage_lock:
            try:
@@ -356,5 +364,5 @@ class NetworkXStorage(BaseGraphStorage):
            except Exception as e:
                logger.error(f"Error saving graph for {self.namespace}: {e}")
                return False  # Return error

        return True

View File

@@ -38,8 +38,8 @@ import pipmaster as pm
if not pm.is_installed("asyncpg"): if not pm.is_installed("asyncpg"):
pm.install("asyncpg") pm.install("asyncpg")
import asyncpg # type: ignore import asyncpg # type: ignore
from asyncpg import Pool # type: ignore from asyncpg import Pool # type: ignore
class PostgreSQLDB: class PostgreSQLDB:

View File

@@ -15,7 +15,7 @@ def direct_log(message, level="INFO"):
print(f"{level}: {message}", file=sys.stderr, flush=True) print(f"{level}: {message}", file=sys.stderr, flush=True)
T = TypeVar('T') T = TypeVar("T")
LockType = Union[ProcessLock, asyncio.Lock] LockType = Union[ProcessLock, asyncio.Lock]
is_multiprocess = None is_multiprocess = None
@@ -26,20 +26,22 @@ _initialized = None
# shared data for storage across processes
_shared_dicts: Optional[Dict[str, Any]] = None
_init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
_update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated

# locks for mutex access
_storage_lock: Optional[LockType] = None
_internal_lock: Optional[LockType] = None
_pipeline_status_lock: Optional[LockType] = None


class UnifiedLock(Generic[T]):
    """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""

    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
        self._lock = lock
        self._is_async = is_async

    async def __aenter__(self) -> "UnifiedLock[T]":
        if self._is_async:
            await self._lock.acquire()
        else:
@@ -52,7 +54,7 @@ class UnifiedLock(Generic[T]):
        else:
            self._lock.release()

    def __enter__(self) -> "UnifiedLock[T]":
        """For backward compatibility"""
        if self._is_async:
            raise RuntimeError("Use 'async with' for shared_storage lock")
@@ -68,24 +70,18 @@ class UnifiedLock(Generic[T]):
def get_internal_lock() -> UnifiedLock:
    """return unified storage lock for data consistency"""
    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)


def get_storage_lock() -> UnifiedLock:
    """return unified storage lock for data consistency"""
    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)


def get_pipeline_status_lock() -> UnifiedLock:
    """return unified storage lock for data consistency"""
    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)


def initialize_share_data(workers: int = 1):
    """
@@ -166,17 +162,19 @@ async def initialize_pipeline_status():
    # Create a shared list object for history_messages
    history_messages = _manager.list() if is_multiprocess else []

    pipeline_namespace.update(
        {
            "busy": False,  # Control concurrent processes
            "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
            "job_start": None,  # Job start time
            "docs": 0,  # Total number of documents to be indexed
            "batchs": 0,  # Number of batches for processing documents
            "cur_batch": 0,  # Current processing batch
            "request_pending": False,  # Flag for pending request for processing
            "latest_message": "",  # Latest message from pipeline processing
            "history_messages": history_messages,  # Use the shared list object
        }
    )

    direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
@@ -195,22 +193,25 @@ async def get_update_flag(namespace: str):
            _update_flags[namespace] = _manager.list()
        else:
            _update_flags[namespace] = []
        direct_log(
            f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
        )

    if is_multiprocess and _manager is not None:
        new_update_flag = _manager.Value("b", False)
    else:
        new_update_flag = False

    _update_flags[namespace].append(new_update_flag)
    return new_update_flag


async def set_all_update_flags(namespace: str):
    """Set all update flag of namespace indicating all workers need to reload data from files"""
    global _update_flags
    if _update_flags is None:
        raise ValueError("Try to create namespace before Shared-Data is initialized")

    async with get_internal_lock():
        if namespace not in _update_flags:
            raise ValueError(f"Namespace {namespace} not found in update flags")
@@ -225,13 +226,13 @@ async def set_all_update_flags(namespace: str):
async def get_all_update_flags_status() -> Dict[str, list]:
    """
    Get update flags status for all namespaces.

    Returns:
        Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses
    """
    if _update_flags is None:
        return {}

    result = {}
    async with get_internal_lock():
        for namespace, flags in _update_flags.items():
@@ -242,7 +243,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
            else:
                worker_statuses.append(flag)
            result[namespace] = worker_statuses

    return result
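To make the flag choreography in get_update_flag / set_all_update_flags easier to follow, here is a deliberately simplified, single-process sketch: every worker registers its own flag per namespace, a writer raises all of them after persisting, and each worker resets only its own flag once it has reloaded. Plain lists stand in for multiprocessing.Manager values, and all names are hypothetical:

    # Toy illustration of the per-namespace update-flag protocol; not the library code.
    from typing import Dict, List

    _update_flags: Dict[str, List[list]] = {}


    def get_update_flag(namespace: str) -> list:
        """Register and return this worker's flag (a one-element list acts as a mutable cell)."""
        flag = [False]
        _update_flags.setdefault(namespace, []).append(flag)
        return flag


    def set_all_update_flags(namespace: str) -> None:
        """Called by the worker that just saved: tell every worker to reload."""
        for flag in _update_flags[namespace]:
            flag[0] = True


    # Worker A and worker B each hold their own flag for a "faiss" namespace.
    flag_a = get_update_flag("faiss")
    flag_b = get_update_flag("faiss")

    set_all_update_flags("faiss")   # worker A persisted new data
    flag_a[0] = False               # worker A resets its own flag to avoid self-reloading
    assert flag_b[0] is True        # worker B will reload before its next read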

View File

@@ -696,7 +696,10 @@ class LightRAG:
        3. Process each chunk for entity and relation extraction
        4. Update the document status
        """
        from lightrag.kg.shared_storage import (
            get_namespace_data,
            get_pipeline_status_lock,
        )

        # Get pipeline status shared data and lock
        pipeline_status = await get_namespace_data("pipeline_status")

View File

@@ -47,10 +47,11 @@ def main():
    # Check and install gunicorn if not present
    import pipmaster as pm

    if not pm.is_installed("gunicorn"):
        print("Installing gunicorn...")
        pm.install("gunicorn")

    # Import Gunicorn's StandaloneApplication
    from gunicorn.app.base import BaseApplication
@@ -103,26 +104,46 @@ def main():
    import gunicorn_config

    # Set configuration variables in gunicorn_config, prioritizing command line arguments
    gunicorn_config.workers = (
        args.workers if args.workers else int(os.getenv("WORKERS", 1))
    )

    # Bind configuration prioritizes command line arguments
    host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
    port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
    gunicorn_config.bind = f"{host}:{port}"

    # Log level configuration prioritizes command line arguments
    gunicorn_config.loglevel = (
        args.log_level.lower()
        if args.log_level
        else os.getenv("LOG_LEVEL", "info")
    )

    # Timeout configuration prioritizes command line arguments
    gunicorn_config.timeout = (
        args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
    )

    # Keepalive configuration
    gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))

    # SSL configuration prioritizes command line arguments
    if args.ssl or os.getenv("SSL", "").lower() in (
        "true",
        "1",
        "yes",
        "t",
        "on",
    ):
        gunicorn_config.certfile = (
            args.ssl_certfile
            if args.ssl_certfile
            else os.getenv("SSL_CERTFILE")
        )
        gunicorn_config.keyfile = (
            args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
        )

    # Set configuration options from the module
    for key in dir(gunicorn_config):
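The configuration hunk above repeats one rule for every setting: a command-line argument wins, otherwise the environment variable is used, otherwise a built-in default. A standalone sketch of that precedence; the resolve helper is hypothetical, not code from this repository:

    # Precedence: explicit CLI value > environment variable > default.
    import os


    def resolve(cli_value, env_name, default, cast=str):
        """Return the CLI value if given, else the environment variable, else the default."""
        if cli_value:
            return cli_value
        raw = os.getenv(env_name)
        return cast(raw) if raw is not None else default


    workers = resolve(None, "WORKERS", 1, cast=int)    # no CLI value: env var or 1
    timeout = resolve(300, "TIMEOUT", 150, cast=int)   # CLI value wins: 300
    print(workers, timeout)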