diff --git a/.env.example b/.env.example index 0f8e6c31..8a14cdb3 100644 --- a/.env.example +++ b/.env.example @@ -141,4 +141,4 @@ QDRANT_URL=http://localhost:16333 # QDRANT_API_KEY=your-api-key ### Redis -REDIS_URI=redis://localhost:6379 \ No newline at end of file +REDIS_URI=redis://localhost:6379 diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 07108c52..270bbb24 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -54,11 +54,12 @@ config.read("config.ini") class LightragPathFilter(logging.Filter): """Filter for lightrag logger to filter out frequent path access logs""" + def __init__(self): super().__init__() # Define paths to be filtered self.filtered_paths = ["/documents", "/health", "/webui/"] - + def filter(self, record): try: # Check if record has the required attributes for an access log @@ -90,11 +91,13 @@ def create_app(args): # Initialize verbose debug setting # Can not use the logger at the top of this module when workers > 1 from lightrag.utils import set_verbose_debug, logger + # Setup logging logger.setLevel(getattr(logging, args.log_level)) set_verbose_debug(args.verbose) from lightrag.kg.shared_storage import is_multiprocess + logger.info(f"==== Multi-processor mode: {is_multiprocess} ====") # Verify that bindings are correctly setup @@ -147,9 +150,7 @@ def create_app(args): # Auto scan documents if enabled if args.auto_scan_at_startup: # Create background task - task = asyncio.create_task( - run_scanning_process(rag, doc_manager) - ) + task = asyncio.create_task(run_scanning_process(rag, doc_manager)) app.state.background_tasks.add(task) task.add_done_callback(app.state.background_tasks.discard) @@ -411,17 +412,19 @@ def get_application(): """Factory function for creating the FastAPI application""" # Configure logging for this worker process configure_logging() - + # Get args from environment variable - args_json = os.environ.get('LIGHTRAG_ARGS') + args_json = os.environ.get("LIGHTRAG_ARGS") if not args_json: args = parse_args() # Fallback to parsing args if env var not set else: import types + args = types.SimpleNamespace(**json.loads(args_json)) - + if args.workers > 1: from lightrag.kg.shared_storage import initialize_share_data + initialize_share_data() return create_app(args) @@ -434,58 +437,61 @@ def configure_logging(): logger = logging.getLogger(logger_name) logger.handlers = [] logger.filters = [] - + # Configure basic logging - logging.config.dictConfig({ - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "%(levelname)s: %(message)s", + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(levelname)s: %(message)s", + }, }, - }, - "handlers": { - "default": { - "formatter": "default", - "class": "logging.StreamHandler", - "stream": "ext://sys.stderr", + "handlers": { + "default": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, }, - }, - "loggers": { - "uvicorn.access": { - "handlers": ["default"], - "level": "INFO", - "propagate": False, - "filters": ["path_filter"], + "loggers": { + "uvicorn.access": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, + "filters": ["path_filter"], + }, + "lightrag": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, + "filters": ["path_filter"], + }, }, - "lightrag": { - "handlers": ["default"], - "level": "INFO", - "propagate": False, - 
"filters": ["path_filter"], + "filters": { + "path_filter": { + "()": "lightrag.api.lightrag_server.LightragPathFilter", + }, }, - }, - "filters": { - "path_filter": { - "()": "lightrag.api.lightrag_server.LightragPathFilter", - }, - }, - }) + } + ) + def main(): from multiprocessing import freeze_support + freeze_support() - + args = parse_args() # Save args to environment variable for child processes - os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args)) + os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args)) # Configure logging before starting uvicorn configure_logging() display_splash_screen(args) - uvicorn_config = { "app": "lightrag.api.lightrag_server:get_application", "factory": True, diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 50bc39df..1f591750 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -375,62 +375,70 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): - """Background task to scan and index documents""" + """Background task to scan and index documents""" scan_progress = get_scan_progress() scan_lock = get_scan_lock() - + # Initialize scan_progress if not already initialized if not scan_progress: - scan_progress.update({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - + scan_progress.update( + { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) + with scan_lock: if scan_progress.get("is_scanning", False): - ASCIIColors.info( - "Skip document scanning(another scanning is active)" - ) + ASCIIColors.info("Skip document scanning(another scanning is active)") return - scan_progress.update({ - "is_scanning": True, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) + scan_progress.update( + { + "is_scanning": True, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) try: new_files = doc_manager.scan_directory_for_new_files() total_files = len(new_files) - scan_progress.update({ - "current_file": "", - "total_files": total_files, - "indexed_count": 0, - "progress": 0, - }) + scan_progress.update( + { + "current_file": "", + "total_files": total_files, + "indexed_count": 0, + "progress": 0, + } + ) logging.info(f"Found {total_files} new files to index.") for idx, file_path in enumerate(new_files): try: progress = (idx / total_files * 100) if total_files > 0 else 0 - scan_progress.update({ - "current_file": os.path.basename(file_path), - "indexed_count": idx, - "progress": progress, - }) - + scan_progress.update( + { + "current_file": os.path.basename(file_path), + "indexed_count": idx, + "progress": progress, + } + ) + await pipeline_index_file(rag, file_path) - + progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0 - scan_progress.update({ - "current_file": os.path.basename(file_path), - "indexed_count": idx + 1, - "progress": progress, - }) + scan_progress.update( + { + "current_file": os.path.basename(file_path), + "indexed_count": idx + 1, + "progress": progress, + } + ) except Exception as e: logging.error(f"Error indexing file {file_path}: {str(e)}") @@ -438,13 +446,15 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): except Exception as e: logging.error(f"Error during scanning process: {str(e)}") finally: - 
scan_progress.update({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) + scan_progress.update( + { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) def create_document_routes( diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 6b501e64..c494101c 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -433,7 +433,6 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.white(" └─ Document Status Storage: ", end="") ASCIIColors.yellow(f"{args.doc_status_storage}") - # Server Status ASCIIColors.green("\n✨ Server starting up...\n") diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 3e59d171..a9d058f4 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -8,14 +8,19 @@ import numpy as np from dataclasses import dataclass import pipmaster as pm -from lightrag.utils import logger,compute_mdhash_id +from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseVectorStorage -from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess +from .shared_storage import ( + get_namespace_data, + get_storage_lock, + get_namespace_object, + is_multiprocess, +) if not pm.is_installed("faiss"): pm.install("faiss") -import faiss # type: ignore +import faiss # type: ignore @final @@ -46,10 +51,10 @@ class FaissVectorDBStorage(BaseVectorStorage): # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim self._storage_lock = get_storage_lock() - - self._index = get_namespace_object('faiss_indices') - self._id_to_meta = get_namespace_data('faiss_meta') - + + self._index = get_namespace_object("faiss_indices") + self._id_to_meta = get_namespace_data("faiss_meta") + with self._storage_lock: if is_multiprocess: if self._index.value is None: @@ -68,7 +73,6 @@ class FaissVectorDBStorage(BaseVectorStorage): self._id_to_meta.update({}) self._load_faiss_index() - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -168,7 +172,9 @@ class FaissVectorDBStorage(BaseVectorStorage): # Perform the similarity search with self._storage_lock: - distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k) + distances, indices = ( + self._index.value if is_multiprocess else self._index + ).search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -232,7 +238,10 @@ class FaissVectorDBStorage(BaseVectorStorage): with self._storage_lock: relations = [] for fid, meta in self._id_to_meta.items(): - if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: + if ( + meta.get("src_id") == entity_name + or meta.get("tgt_id") == entity_name + ): relations.append(fid) logger.debug(f"Found {len(relations)} relations for {entity_name}") @@ -292,7 +301,10 @@ class FaissVectorDBStorage(BaseVectorStorage): Save the current Faiss index + metadata to disk so it can persist across runs. """ with self._storage_lock: - faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file) + faiss.write_index( + self._index.value if is_multiprocess else self._index, + self._faiss_index_file, + ) # Save metadata dict to JSON. Convert all keys to strings for JSON storage. # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... 
} } @@ -320,7 +332,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._index.value = loaded_index else: self._index = loaded_index - + # Load metadata with open(self._meta_file, "r", encoding="utf-8") as f: stored_dict = json.load(f) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index ee5d8a07..4c80854a 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -26,7 +26,6 @@ class JsonKVStorage(BaseKVStorage): self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - async def index_done_callback(self) -> None: # 文件写入需要加锁,防止多个进程同时写入导致文件损坏 with self._storage_lock: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index d1682c7a..7707a0f0 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -25,7 +25,7 @@ class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Initialize lock only for file operations self._storage_lock = get_storage_lock() - + # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -39,22 +39,28 @@ class NanoVectorDBStorage(BaseVectorStorage): self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) self._max_batch_size = self.global_config["embedding_batch_num"] - + self._client = get_namespace_object(self.namespace) - + with self._storage_lock: if is_multiprocess: if self._client.value is None: self._client.value = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + logger.info( + f"Initialized vector DB client for namespace {self.namespace}" ) - logger.info(f"Initialized vector DB client for namespace {self.namespace}") else: if self._client is None: self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + logger.info( + f"Initialized vector DB client for namespace {self.namespace}" ) - logger.info(f"Initialized vector DB client for namespace {self.namespace}") def _get_client(self): """Get the appropriate client instance based on multiprocess mode""" @@ -104,7 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage): # Execute embedding outside of lock to avoid long lock times embedding = await self.embedding_func([query]) embedding = embedding[0] - + with self._storage_lock: client = self._get_client() results = client.query( @@ -150,7 +156,7 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.debug( f"Attempting to delete entity {entity_name} with ID {entity_id}" ) - + with self._storage_lock: client = self._get_client() # Check if the entity exists @@ -172,7 +178,9 @@ class NanoVectorDBStorage(BaseVectorStorage): for dp in storage["data"] if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name ] - logger.debug(f"Found {len(relations)} relations for entity {entity_name}") + logger.debug( + f"Found {len(relations)} relations for entity {entity_name}" + ) ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 74a6ee28..07bd9666 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -78,29 +78,33 @@ class NetworkXStorage(BaseGraphStorage): 
with self._storage_lock: if is_multiprocess: if self._graph.value is None: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + preloaded_graph = NetworkXStorage.load_nx_graph( + self._graphml_xml_file + ) self._graph.value = preloaded_graph or nx.Graph() if preloaded_graph: logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) else: logger.info("Created new empty graph") else: if self._graph is None: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + preloaded_graph = NetworkXStorage.load_nx_graph( + self._graphml_xml_file + ) self._graph = preloaded_graph or nx.Graph() if preloaded_graph: logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) else: logger.info("Created new empty graph") self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, + "node2vec": self._node2vec_embed, } - + def _get_graph(self): """Get the appropriate graph instance based on multiprocess mode""" if is_multiprocess: @@ -248,11 +252,13 @@ class NetworkXStorage(BaseGraphStorage): with self._storage_lock: graph = self._get_graph() - + # Handle special case for "*" label if node_label == "*": # For "*", return the entire graph including all nodes and edges - subgraph = graph.copy() # Create a copy to avoid modifying the original graph + subgraph = ( + graph.copy() + ) # Create a copy to avoid modifying the original graph else: # Find nodes with matching node id (partial match) nodes_to_explore = [] @@ -272,9 +278,9 @@ class NetworkXStorage(BaseGraphStorage): if len(subgraph.nodes()) > max_graph_nodes: origin_nodes = len(subgraph.nodes()) node_degrees = dict(subgraph.degree()) - top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ - :max_graph_nodes - ] + top_nodes = sorted( + node_degrees.items(), key=lambda x: x[1], reverse=True + )[:max_graph_nodes] top_node_ids = [node[0] for node in top_nodes] # Create new subgraph with only top nodes subgraph = subgraph.subgraph(top_node_ids) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 27aca9d0..bd4c55fe 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -17,106 +17,125 @@ _shared_dicts: Optional[Dict[str, Any]] = {} _share_objects: Optional[Dict[str, Any]] = {} _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized + def initialize_share_data(): """Initialize shared data, only called if multiple processes where workers > 1""" global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess is_multiprocess = True - + logger.info(f"Process {os.getpid()} initializing shared storage") - + # Initialize manager if _manager is None: _manager = Manager() logger.info(f"Process {os.getpid()} created manager") - + # Create shared dictionaries with manager _shared_dicts = _manager.dict() _share_objects = _manager.dict() _init_flags = _manager.dict() # 使用共享字典存储初始化标志 logger.info(f"Process {os.getpid()} created shared dictionaries") + def try_initialize_namespace(namespace: str) -> bool: """ 尝试初始化命名空间。返回True表示当前进程获得了初始化权限。 使用共享字典的原子操作确保只有一个进程能成功初始化。 
""" global _init_flags, _manager - + if is_multiprocess: if _init_flags is None: - raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.") + raise RuntimeError( + "Shared storage not initialized. Call initialize_share_data() first." + ) else: if _init_flags is None: _init_flags = {} - + logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") - + # 使用全局锁保护共享字典的访问 with _get_global_lock(): # 检查是否已经初始化 if namespace not in _init_flags: # 设置初始化标志 _init_flags[namespace] = True - logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}") + logger.info( + f"Process {os.getpid()} ready to initialize namespace {namespace}" + ) return True - - logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized") + + logger.info( + f"Process {os.getpid()} found namespace {namespace} already initialized" + ) return False + def _get_global_lock() -> LockType: global _global_lock, is_multiprocess, _manager - + if _global_lock is None: if is_multiprocess: _global_lock = _manager.Lock() # Use manager for lock else: _global_lock = ThreadLock() - + return _global_lock + def get_storage_lock() -> LockType: """return storage lock for data consistency""" return _get_global_lock() + def get_scan_lock() -> LockType: """return scan_progress lock for data consistency""" return get_storage_lock() + def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" global _share_objects, is_multiprocess, _manager - + if is_multiprocess and not _manager: - raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.") + raise RuntimeError( + "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." + ) if namespace not in _share_objects: lock = _get_global_lock() with lock: if namespace not in _share_objects: if is_multiprocess: - _share_objects[namespace] = _manager.Value('O', None) + _share_objects[namespace] = _manager.Value("O", None) else: _share_objects[namespace] = None - + return _share_objects[namespace] + # 移除不再使用的函数 + def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" global _shared_dicts, is_multiprocess, _manager - + if is_multiprocess and not _manager: - raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.") + raise RuntimeError( + "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." 
+ ) if namespace not in _shared_dicts: lock = _get_global_lock() with lock: if namespace not in _shared_dicts: _shared_dicts[namespace] = {} - + return _shared_dicts[namespace] + def get_scan_progress() -> Dict[str, Any]: """get storage space for document scanning progress data""" - return get_namespace_data('scan_progress') + return get_namespace_data("scan_progress") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index d7da6017..46638243 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -266,7 +266,7 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) - def __post_init__(self): + def __post_init__(self): os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") diff --git a/lightrag/utils.py b/lightrag/utils.py index bc78e2cb..a6265048 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -55,6 +55,7 @@ def set_verbose_debug(enabled: bool): global VERBOSE_DEBUG VERBOSE_DEBUG = enabled + statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} # Initialize logger @@ -100,6 +101,7 @@ class UnlimitedSemaphore: ENCODER = None + @dataclass class EmbeddingFunc: embedding_dim: int
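Taken together, the shared_storage.py changes above amount to a small cross-process state API: initialize_share_data() switches the module into Manager-backed multiprocess mode, try_initialize_namespace() lets exactly one process claim first-time setup of a namespace, and get_namespace_data() / get_scan_progress() / get_scan_lock() hand out the shared dictionary and lock. The following is a minimal usage sketch, not part of the diff; the WORKERS value is a placeholder rather than anything read from the server's own argument parsing, and the snippet only illustrates how the calls are meant to be combined.

from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_scan_lock,
    get_scan_progress,
    initialize_share_data,
    try_initialize_namespace,
)

WORKERS = 2  # placeholder; the real server takes this from its parsed args

# With more than one worker, the Manager-backed dictionaries must exist before
# any namespace is touched; with a single worker the module falls back to plain
# dicts and a threading.Lock, so this call is simply skipped.
if WORKERS > 1:
    initialize_share_data()

# Only the first caller to claim a namespace gets True and performs its
# one-time setup.
if try_initialize_namespace("scan_progress"):
    get_namespace_data("scan_progress").update(
        {
            "is_scanning": False,
            "current_file": "",
            "indexed_count": 0,
            "total_files": 0,
            "progress": 0,
        }
    )

# Later reads and writes go through the shared lock, mirroring the pattern in
# run_scanning_process().
with get_scan_lock():
    snapshot = dict(get_scan_progress())
print(snapshot)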
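The LIGHTRAG_ARGS round trip in lightrag_server.py follows the same idea in both directions: main() serializes the parsed arguments into an environment variable, and get_application() rebuilds a namespace object from it inside each worker, falling back to parse_args() when the variable is absent. A stripped-down sketch of that hand-off, using placeholder argument values:

import json
import os
import types

# Parent process: serialize the parsed CLI arguments so spawned workers can
# reconstruct them without re-parsing sys.argv.
args = types.SimpleNamespace(workers=2, log_level="INFO")  # placeholder values
os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args))

# Worker process: rebuild the namespace from the environment variable; the
# real server falls back to parse_args() when LIGHTRAG_ARGS is not set.
args_json = os.environ.get("LIGHTRAG_ARGS")
worker_args = types.SimpleNamespace(**json.loads(args_json)) if args_json else None
print(worker_args.workers if worker_args else "LIGHTRAG_ARGS not set")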