From 087d5770b028da1eb844ddfbefaf9b90bd24410e Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 11:10:13 +0800 Subject: [PATCH] feat(storage): Add shared memory support for file-based storage implementations This commit adds multiprocessing shared memory support to file-based storage implementations: - JsonDocStatusStorage - JsonKVStorage - NanoVectorDBStorage - NetworkXStorage Each storage module now uses module-level global variables with multiprocessing.Manager() to ensure data consistency across multiple uvicorn workers. All processes will see updates immediately when data is modified through ainsert function. --- lightrag/kg/json_doc_status_impl.py | 44 +++++++++++++++++++++- lightrag/kg/json_kv_impl.py | 44 +++++++++++++++++++++- lightrag/kg/nano_vector_db_impl.py | 47 +++++++++++++++++++++-- lightrag/kg/networkx_impl.py | 58 ++++++++++++++++++++++++----- 4 files changed, 176 insertions(+), 17 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 63a295cd..431e340c 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -1,6 +1,8 @@ from dataclasses import dataclass import os from typing import Any, Union, final +import threading +from multiprocessing import Manager from lightrag.base import ( DocProcessingStatus, @@ -13,6 +15,25 @@ from lightrag.utils import ( write_json, ) +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_doc_status_data = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_doc_status_data + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_doc_status_data = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -22,8 +43,27 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data: dict[str, Any] = load_json(self._file_name) or {} - logger.info(f"Loaded document status storage with {len(self._data)} records") + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace data + if self.namespace not in _shared_doc_status_data: + with _init_lock: + if self.namespace not in _shared_doc_status_data: + try: + initial_data = load_json(self._file_name) or {} + _shared_doc_status_data[self.namespace] = initial_data + except Exception as e: + logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}") + raise RuntimeError(f"Shared data initialization failed: {e}") + + try: + self._data = _shared_doc_status_data[self.namespace] + logger.info(f"Loaded document status storage with {len(self._data)} records") + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index e1ea507a..f03fda63 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -2,6 +2,8 @@ import asyncio import os from dataclasses import dataclass from typing import Any, final +import threading +from multiprocessing import Manager from lightrag.base import ( BaseKVStorage, @@ -12,6 +14,25 @@ from lightrag.utils import ( write_json, ) +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_kv_data = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_kv_data + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_kv_data = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -19,9 +40,28 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data: dict[str, Any] = load_json(self._file_name) or {} self._lock = asyncio.Lock() - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace data + if self.namespace not in _shared_kv_data: + with _init_lock: + if self.namespace not in _shared_kv_data: + try: + initial_data = load_json(self._file_name) or {} + _shared_kv_data[self.namespace] = initial_data + except Exception as e: + logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}") + raise RuntimeError(f"Shared data initialization failed: {e}") + + try: + self._data = _shared_kv_data[self.namespace] + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def index_done_callback(self) -> None: write_json(self._data, self._file_name) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index b0900095..d68b7f42 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -3,6 +3,8 @@ import os from typing import Any, final from dataclasses import dataclass import numpy as np +import threading +from multiprocessing import Manager import time @@ -20,6 +22,25 @@ if not pm.is_installed("nano-vectordb"): from nano_vectordb import NanoVectorDB +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_vector_clients = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_vector_clients + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_vector_clients = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -40,9 +61,29 @@ class NanoVectorDBStorage(BaseVectorStorage): self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) self._max_batch_size = self.global_config["embedding_batch_num"] - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name - ) + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace client + if self.namespace not in _shared_vector_clients: + with _init_lock: + if self.namespace not in _shared_vector_clients: + try: + client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name + ) + _shared_vector_clients[self.namespace] = client + except Exception as e: + logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}") + raise RuntimeError(f"Vector DB client initialization failed: {e}") + + try: + self._client = _shared_vector_clients[self.namespace] + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index b4321458..581a4187 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,10 +1,11 @@ import os from dataclasses import dataclass from typing import Any, final +import threading +from multiprocessing import Manager import numpy as np - from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import ( logger, @@ -24,6 +25,25 @@ if not pm.is_installed("graspologic"): import networkx as nx from graspologic import embed +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_graphs = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_graphs + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_graphs = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -78,15 +98,33 @@ class NetworkXStorage(BaseGraphStorage): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - self._graph = preloaded_graph or nx.Graph() - self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, - } + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace graph + if self.namespace not in _shared_graphs: + with _init_lock: + if self.namespace not in _shared_graphs: + try: + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + if preloaded_graph is not None: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + _shared_graphs[self.namespace] = preloaded_graph or nx.Graph() + except Exception as e: + logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}") + raise RuntimeError(f"Graph initialization failed: {e}") + + try: + self._graph = _shared_graphs[self.namespace] + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def index_done_callback(self) -> None: NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)