Refactor initialization logic for vector, KV and graph storage implementations

• Add try_initialize_namespace check
• Move init code out of storage locks
• Reduce redundant init conditions
• Simplify initialization flow
• Make initialization more thread-safe
This commit is contained in:
yangdx
2025-02-27 14:55:07 +08:00
parent 03d05b094d
commit f007ebf006
4 changed files with 67 additions and 60 deletions

View File

@@ -15,6 +15,7 @@ from .shared_storage import (
get_storage_lock,
get_namespace_object,
is_multiprocess,
try_initialize_namespace,
)
if not pm.is_installed("faiss"):
@@ -52,26 +53,26 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._dim = self.embedding_func.embedding_dim
self._storage_lock = get_storage_lock()
# check need_init must before get_namespace_object/get_namespace_data
need_init = try_initialize_namespace("faiss_indices")
self._index = get_namespace_object("faiss_indices")
self._id_to_meta = get_namespace_data("faiss_meta")
with self._storage_lock:
if need_init:
if is_multiprocess:
if self._index.value is None:
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
# If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP.
self._index.value = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc.
# Maps <int faiss_id> → metadata (including your original ID).
self._id_to_meta.update({})
# Attempt to load an existing index + metadata from disk
self._load_faiss_index()
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
# If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP.
self._index.value = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc.
# Maps <int faiss_id> → metadata (including your original ID).
self._id_to_meta.update({})
# Attempt to load an existing index + metadata from disk
self._load_faiss_index()
else:
if self._index is None:
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta.update({})
self._load_faiss_index()
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta.update({})
self._load_faiss_index()
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""

View File

@@ -10,7 +10,7 @@ from lightrag.utils import (
logger,
write_json,
)
from .shared_storage import get_namespace_data, get_storage_lock
from .shared_storage import get_namespace_data, get_storage_lock, try_initialize_namespace
@final
@@ -20,11 +20,15 @@ class JsonKVStorage(BaseKVStorage):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._storage_lock = get_storage_lock()
# check need_init must before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = get_namespace_data(self.namespace)
with self._storage_lock:
if not self._data:
self._data: dict[str, Any] = load_json(self._file_name) or {}
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
if need_init:
loaded_data = load_json(self._file_name) or {}
with self._storage_lock:
self._data.update(loaded_data)
logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
async def index_done_callback(self) -> None:
# 文件写入需要加锁,防止多个进程同时写入导致文件损坏

View File

@@ -11,7 +11,7 @@ from lightrag.utils import (
)
import pipmaster as pm
from lightrag.base import BaseVectorStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
@@ -40,27 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
self._max_batch_size = self.global_config["embedding_batch_num"]
# check need_init must before get_namespace_object
need_init = try_initialize_namespace(self.namespace)
self._client = get_namespace_object(self.namespace)
with self._storage_lock:
if need_init:
if is_multiprocess:
if self._client.value is None:
self._client.value = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
logger.info(
f"Initialized vector DB client for namespace {self.namespace}"
)
self._client.value = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
logger.info(
f"Initialized vector DB client for namespace {self.namespace}"
)
else:
if self._client is None:
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
logger.info(
f"Initialized vector DB client for namespace {self.namespace}"
)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
logger.info(
f"Initialized vector DB client for namespace {self.namespace}"
)
def _get_client(self):
"""Get the appropriate client instance based on multiprocess mode"""

View File

@@ -6,7 +6,7 @@ import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
import pipmaster as pm
@@ -74,32 +74,34 @@ class NetworkXStorage(BaseGraphStorage):
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
self._storage_lock = get_storage_lock()
# check need_init must before get_namespace_object
need_init = try_initialize_namespace(self.namespace)
self._graph = get_namespace_object(self.namespace)
with self._storage_lock:
if need_init:
if is_multiprocess:
if self._graph.value is None:
preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file
preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file
)
self._graph.value = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph.value = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
else:
logger.info("Created new empty graph")
else:
if self._graph is None:
preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file
preloaded_graph = NetworkXStorage.load_nx_graph(
self._graphml_xml_file
)
self._graph = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
else:
logger.info("Created new empty graph")
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,