feat(storage): Add shared memory support for file-based storage implementations
This commit adds multiprocessing shared memory support to the file-based storage implementations: - JsonDocStatusStorage - JsonKVStorage - NanoVectorDBStorage - NetworkXStorage. Each storage module now uses module-level global variables backed by multiprocessing.Manager() to keep data consistent across multiple uvicorn workers. All processes will see updates immediately when data is modified through the ainsert function.
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Any, Union, final
|
||||
import threading
|
||||
from multiprocessing import Manager
|
||||
|
||||
from lightrag.base import (
|
||||
DocProcessingStatus,
|
||||
@@ -13,6 +15,25 @@ from lightrag.utils import (
|
||||
write_json,
|
||||
)
|
||||
|
||||
# Global variables for shared memory management
# Lock taken around the lazy creation of the manager and of per-namespace entries.
# It is a threading.Lock, so it only serializes threads inside one process.
_init_lock = threading.Lock()
# Manager instance; stays None until _get_manager() creates it.
_manager = None
# Manager-backed dict keyed by storage namespace; stays None until _get_manager() runs.
_shared_doc_status_data = None
|
||||
|
||||
|
||||
def _get_manager():
    """Return the module-wide ``Manager``, creating it on first use.

    The first call starts a ``multiprocessing.Manager`` and creates the
    manager-backed dict stored in ``_shared_doc_status_data``. Later calls
    return the manager that already exists.

    Returns:
        The ``Manager`` instance held in the module-level ``_manager``.

    Raises:
        RuntimeError: If the manager or its dict cannot be created. The
            underlying exception is chained as ``__cause__``.
    """
    global _manager, _shared_doc_status_data
    with _init_lock:
        if _manager is None:
            try:
                # Build both objects in locals first so that a failure in
                # ``.dict()`` cannot leave ``_manager`` set while the shared
                # dict is still None.
                manager = Manager()
                shared_data = manager.dict()
            except Exception as e:
                logger.error(f"Failed to initialize shared memory manager: {e}")
                raise RuntimeError(f"Shared memory initialization failed: {e}") from e
            _manager = manager
            _shared_doc_status_data = shared_data
    return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -22,8 +43,27 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
def __post_init__(self):
    """Attach this storage instance to the shared data for its namespace.

    Computes the backing JSON file path from ``global_config["working_dir"]``
    and the namespace, makes sure the module-level manager exists, seeds the
    shared dict entry for this namespace from the JSON file if it is not
    present yet, and finally binds ``self._data`` to that entry.

    Raises:
        RuntimeError: If seeding the namespace entry or reading it back
            fails. The underlying exception is chained as ``__cause__``.
    """
    working_dir = self.global_config["working_dir"]
    self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")

    # Ensure manager is initialized
    _get_manager()

    # Get or create namespace data. The membership test is repeated under
    # the lock so that only one thread loads the file for a namespace.
    if self.namespace not in _shared_doc_status_data:
        with _init_lock:
            if self.namespace not in _shared_doc_status_data:
                try:
                    initial_data = load_json(self._file_name) or {}
                    _shared_doc_status_data[self.namespace] = initial_data
                except Exception as e:
                    logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
                    raise RuntimeError(f"Shared data initialization failed: {e}") from e

    try:
        # The file is no longer loaded into self._data up front: that value
        # was always replaced here, costing an extra read and a second log line.
        # NOTE(review): a plain dict stored inside a manager dict comes back
        # as a copy when indexed, so in-place changes to self._data may not
        # reach other processes - confirm against the multiprocessing docs.
        self._data: dict[str, Any] = _shared_doc_status_data[self.namespace]
        logger.info(f"Loaded document status storage with {len(self._data)} records")
    except Exception as e:
        logger.error(f"Failed to access shared memory: {e}")
        raise RuntimeError(f"Cannot access shared memory: {e}") from e
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||
|
@@ -2,6 +2,8 @@ import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
import threading
|
||||
from multiprocessing import Manager
|
||||
|
||||
from lightrag.base import (
|
||||
BaseKVStorage,
|
||||
@@ -12,6 +14,25 @@ from lightrag.utils import (
|
||||
write_json,
|
||||
)
|
||||
|
||||
# Global variables for shared memory management
# Lock taken around the lazy creation of the manager and of per-namespace entries.
# It is a threading.Lock, so it only serializes threads inside one process.
_init_lock = threading.Lock()
# Manager instance; stays None until _get_manager() creates it.
_manager = None
# Manager-backed dict keyed by storage namespace; stays None until _get_manager() runs.
_shared_kv_data = None
|
||||
|
||||
|
||||
def _get_manager():
    """Return the module-wide ``Manager``, creating it on first use.

    The first call starts a ``multiprocessing.Manager`` and creates the
    manager-backed dict stored in ``_shared_kv_data``. Later calls return the
    manager that already exists.

    Returns:
        The ``Manager`` instance held in the module-level ``_manager``.

    Raises:
        RuntimeError: If the manager or its dict cannot be created. The
            underlying exception is chained as ``__cause__``.
    """
    global _manager, _shared_kv_data
    with _init_lock:
        if _manager is None:
            try:
                # Build both objects in locals first so that a failure in
                # ``.dict()`` cannot leave ``_manager`` set while the shared
                # dict is still None.
                manager = Manager()
                shared_data = manager.dict()
            except Exception as e:
                logger.error(f"Failed to initialize shared memory manager: {e}")
                raise RuntimeError(f"Shared memory initialization failed: {e}") from e
            _manager = manager
            _shared_kv_data = shared_data
    return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -19,9 +40,28 @@ class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
    """Attach this KV storage instance to the shared data for its namespace.

    Computes the backing JSON file path from ``global_config["working_dir"]``
    and the namespace, creates the instance's asyncio lock, makes sure the
    module-level manager exists, seeds the shared dict entry for this
    namespace from the JSON file if it is not present yet, and finally binds
    ``self._data`` to that entry.

    Raises:
        RuntimeError: If seeding the namespace entry or reading it back
            fails. The underlying exception is chained as ``__cause__``.
    """
    working_dir = self.global_config["working_dir"]
    self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
    self._lock = asyncio.Lock()

    # Ensure manager is initialized
    _get_manager()

    # Get or create namespace data. The membership test is repeated under
    # the lock so that only one thread loads the file for a namespace.
    if self.namespace not in _shared_kv_data:
        with _init_lock:
            if self.namespace not in _shared_kv_data:
                try:
                    initial_data = load_json(self._file_name) or {}
                    _shared_kv_data[self.namespace] = initial_data
                except Exception as e:
                    logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
                    raise RuntimeError(f"Shared data initialization failed: {e}") from e

    try:
        # The file is no longer loaded into self._data up front: that value
        # was always replaced here, costing an extra read and a second log line.
        # NOTE(review): a plain dict stored inside a manager dict comes back
        # as a copy when indexed, so in-place changes to self._data may not
        # reach other processes - confirm against the multiprocessing docs.
        self._data: dict[str, Any] = _shared_kv_data[self.namespace]
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
    except Exception as e:
        logger.error(f"Failed to access shared memory: {e}")
        raise RuntimeError(f"Cannot access shared memory: {e}") from e
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
write_json(self._data, self._file_name)
|
||||
|
@@ -3,6 +3,8 @@ import os
|
||||
from typing import Any, final
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import threading
|
||||
from multiprocessing import Manager
|
||||
|
||||
import time
|
||||
|
||||
@@ -20,6 +22,25 @@ if not pm.is_installed("nano-vectordb"):
|
||||
|
||||
from nano_vectordb import NanoVectorDB
|
||||
|
||||
# Global variables for shared memory management
# Lock taken around the lazy creation of the manager and of per-namespace clients.
# It is a threading.Lock, so it only serializes threads inside one process.
_init_lock = threading.Lock()
# Manager instance; stays None until _get_manager() creates it.
_manager = None
# Manager-backed dict mapping namespace to its NanoVectorDB client; None until _get_manager() runs.
_shared_vector_clients = None
|
||||
|
||||
|
||||
def _get_manager():
    """Return the module-wide ``Manager``, creating it on first use.

    The first call starts a ``multiprocessing.Manager`` and creates the
    manager-backed dict stored in ``_shared_vector_clients``. Later calls
    return the manager that already exists.

    Returns:
        The ``Manager`` instance held in the module-level ``_manager``.

    Raises:
        RuntimeError: If the manager or its dict cannot be created. The
            underlying exception is chained as ``__cause__``.
    """
    global _manager, _shared_vector_clients
    with _init_lock:
        if _manager is None:
            try:
                # Build both objects in locals first so that a failure in
                # ``.dict()`` cannot leave ``_manager`` set while the shared
                # dict is still None.
                manager = Manager()
                shared_clients = manager.dict()
            except Exception as e:
                logger.error(f"Failed to initialize shared memory manager: {e}")
                raise RuntimeError(f"Shared memory initialization failed: {e}") from e
            _manager = manager
            _shared_vector_clients = shared_clients
    return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -40,9 +61,29 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
||||
)
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
|
||||
# Get or create namespace client
|
||||
if self.namespace not in _shared_vector_clients:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_vector_clients:
|
||||
try:
|
||||
client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name
|
||||
)
|
||||
_shared_vector_clients[self.namespace] = client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Vector DB client initialization failed: {e}")
|
||||
|
||||
try:
|
||||
self._client = _shared_vector_clients[self.namespace]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
|
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
import threading
|
||||
from multiprocessing import Manager
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
@@ -24,6 +25,25 @@ if not pm.is_installed("graspologic"):
|
||||
import networkx as nx
|
||||
from graspologic import embed
|
||||
|
||||
# Global variables for shared memory management
# Lock taken around the lazy creation of the manager and of per-namespace graphs.
# It is a threading.Lock, so it only serializes threads inside one process.
_init_lock = threading.Lock()
# Manager instance; stays None until _get_manager() creates it.
_manager = None
# Manager-backed dict mapping namespace to its graph object; None until _get_manager() runs.
_shared_graphs = None
|
||||
|
||||
|
||||
def _get_manager():
    """Return the module-wide ``Manager``, creating it on first use.

    The first call starts a ``multiprocessing.Manager`` and creates the
    manager-backed dict stored in ``_shared_graphs``. Later calls return the
    manager that already exists.

    Returns:
        The ``Manager`` instance held in the module-level ``_manager``.

    Raises:
        RuntimeError: If the manager or its dict cannot be created. The
            underlying exception is chained as ``__cause__``.
    """
    global _manager, _shared_graphs
    with _init_lock:
        if _manager is None:
            try:
                # Build both objects in locals first so that a failure in
                # ``.dict()`` cannot leave ``_manager`` set while the shared
                # dict is still None.
                manager = Manager()
                shared_graphs = manager.dict()
            except Exception as e:
                logger.error(f"Failed to initialize shared memory manager: {e}")
                raise RuntimeError(f"Shared memory initialization failed: {e}") from e
            _manager = manager
            _shared_graphs = shared_graphs
    return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -78,15 +98,33 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
self._graphml_xml_file = os.path.join(
|
||||
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
|
||||
)
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
if preloaded_graph is not None:
|
||||
logger.info(
|
||||
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
self._graph = preloaded_graph or nx.Graph()
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
|
||||
# Get or create namespace graph
|
||||
if self.namespace not in _shared_graphs:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_graphs:
|
||||
try:
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
if preloaded_graph is not None:
|
||||
logger.info(
|
||||
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
_shared_graphs[self.namespace] = preloaded_graph or nx.Graph()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Graph initialization failed: {e}")
|
||||
|
||||
try:
|
||||
self._graph = _shared_graphs[self.namespace]
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
|
Reference in New Issue
Block a user