Refactor initialization logic for vector, KV and graph storage implementations

• Add try_initialize_namespace check
• Move init code out of storage locks
• Reduce redundant init conditions
• Simplify initialization flow
• Make initialization more thread-safe
This commit is contained in:
yangdx
2025-02-27 14:55:07 +08:00
parent 03d05b094d
commit f007ebf006
4 changed files with 67 additions and 60 deletions

View File

@@ -15,6 +15,7 @@ from .shared_storage import (
get_storage_lock, get_storage_lock,
get_namespace_object, get_namespace_object,
is_multiprocess, is_multiprocess,
try_initialize_namespace,
) )
if not pm.is_installed("faiss"): if not pm.is_installed("faiss"):
@@ -52,12 +53,13 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._dim = self.embedding_func.embedding_dim self._dim = self.embedding_func.embedding_dim
self._storage_lock = get_storage_lock() self._storage_lock = get_storage_lock()
# need_init must be determined before calling get_namespace_object/get_namespace_data
need_init = try_initialize_namespace("faiss_indices")
self._index = get_namespace_object("faiss_indices") self._index = get_namespace_object("faiss_indices")
self._id_to_meta = get_namespace_data("faiss_meta") self._id_to_meta = get_namespace_data("faiss_meta")
with self._storage_lock: if need_init:
if is_multiprocess: if is_multiprocess:
if self._index.value is None:
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
# If you have a large number of vectors, you might want IVF or other indexes. # If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP. # For demonstration, we use a simple IndexFlatIP.
@@ -68,7 +70,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Attempt to load an existing index + metadata from disk # Attempt to load an existing index + metadata from disk
self._load_faiss_index() self._load_faiss_index()
else: else:
if self._index is None:
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta.update({}) self._id_to_meta.update({})
self._load_faiss_index() self._load_faiss_index()

View File

@@ -10,7 +10,7 @@ from lightrag.utils import (
logger, logger,
write_json, write_json,
) )
from .shared_storage import get_namespace_data, get_storage_lock from .shared_storage import get_namespace_data, get_storage_lock, try_initialize_namespace
@final @final
@@ -20,11 +20,15 @@ class JsonKVStorage(BaseKVStorage):
working_dir = self.global_config["working_dir"] working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._storage_lock = get_storage_lock() self._storage_lock = get_storage_lock()
# need_init must be determined before calling get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = get_namespace_data(self.namespace) self._data = get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
with self._storage_lock: with self._storage_lock:
if not self._data: self._data.update(loaded_data)
self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# 文件写入需要加锁,防止多个进程同时写入导致文件损坏 # 文件写入需要加锁,防止多个进程同时写入导致文件损坏

View File

@@ -11,7 +11,7 @@ from lightrag.utils import (
) )
import pipmaster as pm import pipmaster as pm
from lightrag.base import BaseVectorStorage from lightrag.base import BaseVectorStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
if not pm.is_installed("nano-vectordb"): if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb") pm.install("nano-vectordb")
@@ -40,11 +40,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
# need_init must be determined before calling get_namespace_object
need_init = try_initialize_namespace(self.namespace)
self._client = get_namespace_object(self.namespace) self._client = get_namespace_object(self.namespace)
with self._storage_lock: if need_init:
if is_multiprocess: if is_multiprocess:
if self._client.value is None:
self._client.value = NanoVectorDB( self._client.value = NanoVectorDB(
self.embedding_func.embedding_dim, self.embedding_func.embedding_dim,
storage_file=self._client_file_name, storage_file=self._client_file_name,
@@ -53,7 +54,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
f"Initialized vector DB client for namespace {self.namespace}" f"Initialized vector DB client for namespace {self.namespace}"
) )
else: else:
if self._client is None:
self._client = NanoVectorDB( self._client = NanoVectorDB(
self.embedding_func.embedding_dim, self.embedding_func.embedding_dim,
storage_file=self._client_file_name, storage_file=self._client_file_name,

View File

@@ -6,7 +6,7 @@ import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger from lightrag.utils import logger
from lightrag.base import BaseGraphStorage from lightrag.base import BaseGraphStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
import pipmaster as pm import pipmaster as pm
@@ -74,10 +74,13 @@ class NetworkXStorage(BaseGraphStorage):
self.global_config["working_dir"], f"graph_{self.namespace}.graphml" self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
) )
self._storage_lock = get_storage_lock() self._storage_lock = get_storage_lock()
# need_init must be determined before calling get_namespace_object
need_init = try_initialize_namespace(self.namespace)
self._graph = get_namespace_object(self.namespace) self._graph = get_namespace_object(self.namespace)
with self._storage_lock:
if need_init:
if is_multiprocess: if is_multiprocess:
if self._graph.value is None:
preloaded_graph = NetworkXStorage.load_nx_graph( preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file self._graphml_xml_file
) )
@@ -89,7 +92,6 @@ class NetworkXStorage(BaseGraphStorage):
else: else:
logger.info("Created new empty graph") logger.info("Created new empty graph")
else: else:
if self._graph is None:
preloaded_graph = NetworkXStorage.load_nx_graph( preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file self._graphml_xml_file
) )