feat(storage): Add shared memory support for file-based storage implementations

This commit adds multiprocessing shared memory support to file-based storage implementations:
- JsonDocStatusStorage
- JsonKVStorage
- NanoVectorDBStorage
- NetworkXStorage

Each storage module now uses module-level global variables with multiprocessing.Manager() to ensure data consistency across multiple uvicorn workers. All processes see
updates immediately when data is modified through the `ainsert` function.
This commit is contained in:
yangdx
2025-02-25 11:10:13 +08:00
parent 7262f61b0e
commit 087d5770b0
4 changed files with 176 additions and 17 deletions

View File

@@ -1,6 +1,8 @@
from dataclasses import dataclass
import os
from typing import Any, Union, final
import threading
from multiprocessing import Manager
from lightrag.base import (
DocProcessingStatus,
@@ -13,6 +15,25 @@ from lightrag.utils import (
write_json,
)
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_doc_status_data = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_doc_status_data
with _init_lock:
if _manager is None:
try:
_manager = Manager()
_shared_doc_status_data = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
@final
@dataclass
@@ -22,8 +43,27 @@ class JsonDocStatusStorage(DocStatusStorage):
def __post_init__(self):
    """Bind this storage instance to the cross-process shared status dict.

    The namespace's JSON file is loaded only once (first initializer wins,
    guarded by double-checked locking); subsequent instances attach to the
    already-shared data instead of re-reading the file.

    Raises:
        RuntimeError: if the shared data cannot be initialized or accessed.
    """
    working_dir = self.global_config["working_dir"]
    self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
    # Ensure the multiprocessing manager / shared dict exist before use.
    _get_manager()
    # Double-checked locking: only the first initializer for this namespace
    # pays the cost of loading the JSON file from disk.
    if self.namespace not in _shared_doc_status_data:
        with _init_lock:
            if self.namespace not in _shared_doc_status_data:
                try:
                    _shared_doc_status_data[self.namespace] = load_json(self._file_name) or {}
                except Exception as e:
                    logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
                    raise RuntimeError(f"Shared data initialization failed: {e}") from e
    try:
        self._data = _shared_doc_status_data[self.namespace]
        logger.info(f"Loaded document status storage with {len(self._data)} records")
    except Exception as e:
        logger.error(f"Failed to access shared memory: {e}")
        raise RuntimeError(f"Cannot access shared memory: {e}") from e
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""

View File

@@ -2,6 +2,8 @@ import asyncio
import os
from dataclasses import dataclass
from typing import Any, final
import threading
from multiprocessing import Manager
from lightrag.base import (
BaseKVStorage,
@@ -12,6 +14,25 @@ from lightrag.utils import (
write_json,
)
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_kv_data = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_kv_data
with _init_lock:
if _manager is None:
try:
_manager = Manager()
_shared_kv_data = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
@final
@dataclass
@@ -19,9 +40,28 @@ class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
    """Bind this KV storage instance to the cross-process shared dict.

    The namespace's JSON file is loaded only once (first initializer wins,
    guarded by double-checked locking); subsequent instances attach to the
    already-shared data instead of re-reading the file.

    Raises:
        RuntimeError: if the shared data cannot be initialized or accessed.
    """
    working_dir = self.global_config["working_dir"]
    self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
    # The pre-shared-memory code created this async lock; keep it so
    # concurrent mutations within one process stay serialized.
    self._lock = asyncio.Lock()
    # Ensure the multiprocessing manager / shared dict exist before use.
    _get_manager()
    # Double-checked locking: only the first initializer for this namespace
    # pays the cost of loading the JSON file from disk.
    if self.namespace not in _shared_kv_data:
        with _init_lock:
            if self.namespace not in _shared_kv_data:
                try:
                    _shared_kv_data[self.namespace] = load_json(self._file_name) or {}
                except Exception as e:
                    logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
                    raise RuntimeError(f"Shared data initialization failed: {e}") from e
    try:
        self._data = _shared_kv_data[self.namespace]
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
    except Exception as e:
        logger.error(f"Failed to access shared memory: {e}")
        raise RuntimeError(f"Cannot access shared memory: {e}") from e
async def index_done_callback(self) -> None:
    """Persist the in-memory KV data for this namespace to its JSON file."""
    write_json(self._data, self._file_name)

View File

@@ -3,6 +3,8 @@ import os
from typing import Any, final
from dataclasses import dataclass
import numpy as np
import threading
from multiprocessing import Manager
import time
@@ -20,6 +22,25 @@ if not pm.is_installed("nano-vectordb"):
from nano_vectordb import NanoVectorDB
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_vector_clients = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_vector_clients
with _init_lock:
if _manager is None:
try:
_manager = Manager()
_shared_vector_clients = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
@final
@dataclass
@@ -40,9 +61,29 @@ class NanoVectorDBStorage(BaseVectorStorage):
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
# Ensure manager is initialized
_get_manager()
# Get or create namespace client
if self.namespace not in _shared_vector_clients:
with _init_lock:
if self.namespace not in _shared_vector_clients:
try:
client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name
)
_shared_vector_clients[self.namespace] = client
except Exception as e:
logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}")
raise RuntimeError(f"Vector DB client initialization failed: {e}")
try:
self._client = _shared_vector_clients[self.namespace]
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")

View File

@@ -1,10 +1,11 @@
import os
from dataclasses import dataclass
from typing import Any, final
import threading
from multiprocessing import Manager
import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import (
logger,
@@ -24,6 +25,25 @@ if not pm.is_installed("graspologic"):
import networkx as nx
from graspologic import embed
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_graphs = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_graphs
with _init_lock:
if _manager is None:
try:
_manager = Manager()
_shared_graphs = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
@final
@dataclass
@@ -78,15 +98,33 @@ class NetworkXStorage(BaseGraphStorage):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
# Ensure manager is initialized
_get_manager()
# Get or create namespace graph
if self.namespace not in _shared_graphs:
with _init_lock:
if self.namespace not in _shared_graphs:
try:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
_shared_graphs[self.namespace] = preloaded_graph or nx.Graph()
except Exception as e:
logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}")
raise RuntimeError(f"Graph initialization failed: {e}")
try:
self._graph = _shared_graphs[self.namespace]
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
async def index_done_callback(self) -> None:
    """Serialize the in-memory graph for this namespace to its GraphML file."""
    NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)