From f007ebf006815a1854595aef665208c928626bc0 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Thu, 27 Feb 2025 14:55:07 +0800
Subject: [PATCH] Refactor initialization logic for vector, KV and graph
 storage implementations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Add try_initialize_namespace check
• Move init code out of storage locks
• Reduce redundant init conditions
• Simplify initialization flow
• Make init thread-safer
---
 lightrag/kg/faiss_impl.py          | 31 ++++++++++----------
 lightrag/kg/json_kv_impl.py        | 14 +++++----
 lightrag/kg/nano_vector_db_impl.py | 36 +++++++++++------------
 lightrag/kg/networkx_impl.py       | 46 ++++++++++++++++--------------
 4 files changed, 67 insertions(+), 60 deletions(-)

diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index a9d058f4..0315de7c 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -15,6 +15,7 @@ from .shared_storage import (
     get_storage_lock,
     get_namespace_object,
     is_multiprocess,
+    try_initialize_namespace,
 )

 if not pm.is_installed("faiss"):
@@ -52,26 +53,26 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._dim = self.embedding_func.embedding_dim
         self._storage_lock = get_storage_lock()

+        # need_init check must come before get_namespace_object/get_namespace_data
+        need_init = try_initialize_namespace("faiss_indices")
         self._index = get_namespace_object("faiss_indices")
         self._id_to_meta = get_namespace_data("faiss_meta")

-        with self._storage_lock:
+        if need_init:
             if is_multiprocess:
-                if self._index.value is None:
-                    # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
-                    # If you have a large number of vectors, you might want IVF or other indexes.
-                    # For demonstration, we use a simple IndexFlatIP.
-                    self._index.value = faiss.IndexFlatIP(self._dim)
-                    # Keep a local store for metadata, IDs, etc.
-                    # Maps faiss_id → metadata (including your original ID).
-                    self._id_to_meta.update({})
-                    # Attempt to load an existing index + metadata from disk
-                    self._load_faiss_index()
+                # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+                # If you have a large number of vectors, you might want IVF or other indexes.
+                # For demonstration, we use a simple IndexFlatIP.
+                self._index.value = faiss.IndexFlatIP(self._dim)
+                # Keep a local store for metadata, IDs, etc.
+                # Maps faiss_id → metadata (including your original ID).
+                self._id_to_meta.update({})
+                # Attempt to load an existing index + metadata from disk
+                self._load_faiss_index()
             else:
-                if self._index is None:
-                    self._index = faiss.IndexFlatIP(self._dim)
-                    self._id_to_meta.update({})
-                    self._load_faiss_index()
+                self._index = faiss.IndexFlatIP(self._dim)
+                self._id_to_meta.update({})
+                self._load_faiss_index()

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 4c80854a..f13cdfb6 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -10,7 +10,7 @@ from lightrag.utils import (
     logger,
     write_json,
 )
-from .shared_storage import get_namespace_data, get_storage_lock
+from .shared_storage import get_namespace_data, get_storage_lock, try_initialize_namespace


 @final
@@ -20,11 +20,15 @@ class JsonKVStorage(BaseKVStorage):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._storage_lock = get_storage_lock()
+
+        # need_init check must come before get_namespace_data
+        need_init = try_initialize_namespace(self.namespace)
         self._data = get_namespace_data(self.namespace)
-        with self._storage_lock:
-            if not self._data:
-                self._data: dict[str, Any] = load_json(self._file_name) or {}
-                logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
+        if need_init:
+            loaded_data = load_json(self._file_name) or {}
+            with self._storage_lock:
+                self._data.update(loaded_data)
+                logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")

     async def index_done_callback(self) -> None:
         # File writes must be locked to keep multiple processes from writing at the same time and corrupting the file
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index 7707a0f0..64b0e720 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -11,7 +11,7 @@ from lightrag.utils import (
 )
 import pipmaster as pm
 from lightrag.base import BaseVectorStorage
-from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
+from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace

 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")
@@ -40,27 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]

+        # need_init check must come before get_namespace_object
+        need_init = try_initialize_namespace(self.namespace)
         self._client = get_namespace_object(self.namespace)
-        with self._storage_lock:
+        if need_init:
             if is_multiprocess:
-                if self._client.value is None:
-                    self._client.value = NanoVectorDB(
-                        self.embedding_func.embedding_dim,
-                        storage_file=self._client_file_name,
-                    )
-                    logger.info(
-                        f"Initialized vector DB client for namespace {self.namespace}"
-                    )
+                self._client.value = NanoVectorDB(
+                    self.embedding_func.embedding_dim,
+                    storage_file=self._client_file_name,
+                )
+                logger.info(
+                    f"Initialized vector DB client for namespace {self.namespace}"
+                )
             else:
-                if self._client is None:
-                    self._client = NanoVectorDB(
-                        self.embedding_func.embedding_dim,
-                        storage_file=self._client_file_name,
-                    )
-                    logger.info(
-                        f"Initialized vector DB client for namespace {self.namespace}"
-                    )
+                self._client = NanoVectorDB(
+                    self.embedding_func.embedding_dim,
+                    storage_file=self._client_file_name,
+                )
+                logger.info(
+                    f"Initialized vector DB client for namespace {self.namespace}"
+                )

     def _get_client(self):
         """Get the appropriate client instance based on multiprocess mode"""
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index 07bd9666..aec49e6c 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -6,7 +6,7 @@ import numpy as np
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
-from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
+from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace

 import pipmaster as pm

@@ -74,32 +74,34 @@ class NetworkXStorage(BaseGraphStorage):
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
         self._storage_lock = get_storage_lock()
+
+        # need_init check must come before get_namespace_object
+        need_init = try_initialize_namespace(self.namespace)
         self._graph = get_namespace_object(self.namespace)
-        with self._storage_lock:
+
+        if need_init:
             if is_multiprocess:
-                if self._graph.value is None:
-                    preloaded_graph = NetworkXStorage.load_nx_graph(
-                        self._graphml_xml_file
+                preloaded_graph = NetworkXStorage.load_nx_graph(
+                    self._graphml_xml_file
+                )
+                self._graph.value = preloaded_graph or nx.Graph()
+                if preloaded_graph:
+                    logger.info(
+                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
                     )
-                    self._graph.value = preloaded_graph or nx.Graph()
-                    if preloaded_graph:
-                        logger.info(
-                            f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                        )
-                    else:
-                        logger.info("Created new empty graph")
+                else:
+                    logger.info("Created new empty graph")
             else:
-                if self._graph is None:
-                    preloaded_graph = NetworkXStorage.load_nx_graph(
-                        self._graphml_xml_file
+                preloaded_graph = NetworkXStorage.load_nx_graph(
+                    self._graphml_xml_file
+                )
+                self._graph = preloaded_graph or nx.Graph()
+                if preloaded_graph:
+                    logger.info(
+                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
                     )
-                    self._graph = preloaded_graph or nx.Graph()
-                    if preloaded_graph:
-                        logger.info(
-                            f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                        )
-                    else:
-                        logger.info("Created new empty graph")
+                else:
+                    logger.info("Created new empty graph")

         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
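
For reference, the sketch below restates the initialization pattern that each hunk above applies, using the shared_storage helpers the diff itself imports (try_initialize_namespace, get_namespace_data, get_storage_lock). It is a minimal illustration, not LightRAG code: the ExampleKVStorage class, its file path, and the plain-json load are assumptions made for the example.

# Minimal sketch of the pattern introduced by this patch; assumes the helpers
# behave as they are used in the hunks above (try_initialize_namespace is truthy
# for exactly one initializer, get_namespace_data returns a shared dict, and
# get_storage_lock returns a context-manager lock).
import json
import os
from typing import Any

from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_storage_lock,
    try_initialize_namespace,
)


class ExampleKVStorage:  # hypothetical storage class, not part of the patch
    def __init__(self, namespace: str, working_dir: str):
        self._file_name = os.path.join(working_dir, f"kv_store_{namespace}.json")
        self._storage_lock = get_storage_lock()

        # Check need_init before fetching the shared namespace data, so that
        # exactly one process/thread wins the right to initialize it.
        need_init = try_initialize_namespace(namespace)
        self._data = get_namespace_data(namespace)

        if need_init:
            # Slow disk I/O happens outside the lock ...
            loaded: dict[str, Any] = {}
            if os.path.exists(self._file_name):
                with open(self._file_name, encoding="utf-8") as f:
                    text = f.read().strip()
                loaded = json.loads(text) if text else {}
            # ... and the lock is held only for the brief shared-dict update.
            with self._storage_lock:
                self._data.update(loaded)

Compared with the old pattern (acquire the lock, then test whether the shared object is still empty), the need_init flag removes the redundant emptiness checks and keeps file loading out of the critical section, which is what "Move init code out of storage locks" in the commit message refers to.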