Refactor initialization logic for vector, KV and graph storage implementations

• Add try_initialize_namespace check
• Move init code out of storage locks
• Reduce redundant init conditions
• Simplify initialization flow
• Improve thread safety of initialization
This commit is contained in:
yangdx
2025-02-27 14:55:07 +08:00
parent 03d05b094d
commit f007ebf006
4 changed files with 67 additions and 60 deletions

View File

@@ -15,6 +15,7 @@ from .shared_storage import (
get_storage_lock, get_storage_lock,
get_namespace_object, get_namespace_object,
is_multiprocess, is_multiprocess,
try_initialize_namespace,
) )
if not pm.is_installed("faiss"): if not pm.is_installed("faiss"):
@@ -52,26 +53,26 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._dim = self.embedding_func.embedding_dim self._dim = self.embedding_func.embedding_dim
self._storage_lock = get_storage_lock() self._storage_lock = get_storage_lock()
# check need_init must before get_namespace_object/get_namespace_data
need_init = try_initialize_namespace("faiss_indices")
self._index = get_namespace_object("faiss_indices") self._index = get_namespace_object("faiss_indices")
self._id_to_meta = get_namespace_data("faiss_meta") self._id_to_meta = get_namespace_data("faiss_meta")
with self._storage_lock: if need_init:
if is_multiprocess: if is_multiprocess:
if self._index.value is None: # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). # If you have a large number of vectors, you might want IVF or other indexes.
# If you have a large number of vectors, you might want IVF or other indexes. # For demonstration, we use a simple IndexFlatIP.
# For demonstration, we use a simple IndexFlatIP. self._index.value = faiss.IndexFlatIP(self._dim)
self._index.value = faiss.IndexFlatIP(self._dim) # Keep a local store for metadata, IDs, etc.
# Keep a local store for metadata, IDs, etc. # Maps <int faiss_id> → metadata (including your original ID).
# Maps <int faiss_id> → metadata (including your original ID). self._id_to_meta.update({})
self._id_to_meta.update({}) # Attempt to load an existing index + metadata from disk
# Attempt to load an existing index + metadata from disk self._load_faiss_index()
self._load_faiss_index()
else: else:
if self._index is None: self._index = faiss.IndexFlatIP(self._dim)
self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta.update({})
self._id_to_meta.update({}) self._load_faiss_index()
self._load_faiss_index()
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
""" """

View File

@@ -10,7 +10,7 @@ from lightrag.utils import (
logger, logger,
write_json, write_json,
) )
from .shared_storage import get_namespace_data, get_storage_lock from .shared_storage import get_namespace_data, get_storage_lock, try_initialize_namespace
@final @final
@@ -20,11 +20,15 @@ class JsonKVStorage(BaseKVStorage):
working_dir = self.global_config["working_dir"] working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._storage_lock = get_storage_lock() self._storage_lock = get_storage_lock()
# check need_init must before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = get_namespace_data(self.namespace) self._data = get_namespace_data(self.namespace)
with self._storage_lock: if need_init:
if not self._data: loaded_data = load_json(self._file_name) or {}
self._data: dict[str, Any] = load_json(self._file_name) or {} with self._storage_lock:
logger.info(f"Load KV {self.namespace} with {len(self._data)} data") self._data.update(loaded_data)
logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# 文件写入需要加锁,防止多个进程同时写入导致文件损坏 # 文件写入需要加锁,防止多个进程同时写入导致文件损坏

View File

@@ -11,7 +11,7 @@ from lightrag.utils import (
) )
import pipmaster as pm import pipmaster as pm
from lightrag.base import BaseVectorStorage from lightrag.base import BaseVectorStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
if not pm.is_installed("nano-vectordb"): if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb") pm.install("nano-vectordb")
@@ -40,27 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
# check need_init must before get_namespace_object
need_init = try_initialize_namespace(self.namespace)
self._client = get_namespace_object(self.namespace) self._client = get_namespace_object(self.namespace)
with self._storage_lock: if need_init:
if is_multiprocess: if is_multiprocess:
if self._client.value is None: self._client.value = NanoVectorDB(
self._client.value = NanoVectorDB( self.embedding_func.embedding_dim,
self.embedding_func.embedding_dim, storage_file=self._client_file_name,
storage_file=self._client_file_name, )
) logger.info(
logger.info( f"Initialized vector DB client for namespace {self.namespace}"
f"Initialized vector DB client for namespace {self.namespace}" )
)
else: else:
if self._client is None: self._client = NanoVectorDB(
self._client = NanoVectorDB( self.embedding_func.embedding_dim,
self.embedding_func.embedding_dim, storage_file=self._client_file_name,
storage_file=self._client_file_name, )
) logger.info(
logger.info( f"Initialized vector DB client for namespace {self.namespace}"
f"Initialized vector DB client for namespace {self.namespace}" )
)
def _get_client(self): def _get_client(self):
"""Get the appropriate client instance based on multiprocess mode""" """Get the appropriate client instance based on multiprocess mode"""

View File

@@ -6,7 +6,7 @@ import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger from lightrag.utils import logger
from lightrag.base import BaseGraphStorage from lightrag.base import BaseGraphStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
import pipmaster as pm import pipmaster as pm
@@ -74,32 +74,34 @@ class NetworkXStorage(BaseGraphStorage):
self.global_config["working_dir"], f"graph_{self.namespace}.graphml" self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
) )
self._storage_lock = get_storage_lock() self._storage_lock = get_storage_lock()
# check need_init must before get_namespace_object
need_init = try_initialize_namespace(self.namespace)
self._graph = get_namespace_object(self.namespace) self._graph = get_namespace_object(self.namespace)
with self._storage_lock:
if need_init:
if is_multiprocess: if is_multiprocess:
if self._graph.value is None: preloaded_graph = NetworkXStorage.load_nx_graph(
preloaded_graph = NetworkXStorage.load_nx_graph( self._graphml_xml_file
self._graphml_xml_file )
self._graph.value = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
) )
self._graph.value = preloaded_graph or nx.Graph() else:
if preloaded_graph: logger.info("Created new empty graph")
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
else: else:
if self._graph is None: preloaded_graph = NetworkXStorage.load_nx_graph(
preloaded_graph = NetworkXStorage.load_nx_graph( self._graphml_xml_file
self._graphml_xml_file )
self._graph = preloaded_graph or nx.Graph()
if preloaded_graph:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
) )
self._graph = preloaded_graph or nx.Graph() else:
if preloaded_graph: logger.info("Created new empty graph")
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
self._node_embed_algorithms = { self._node_embed_algorithms = {
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,