diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index b6b998e4..a3520653 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -10,19 +10,12 @@ import pipmaster as pm
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_namespace_data,
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)
 
 if not pm.is_installed("faiss"):
     pm.install("faiss")
 
 import faiss  # type: ignore
-
+from threading import Lock as ThreadLock
 
 @final
 @dataclass
@@ -51,35 +44,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = ThreadLock()
 
-        # check need_init must before get_namespace_object/get_namespace_data
-        need_init = try_initialize_namespace("faiss_indices")
-        self._index = get_namespace_object("faiss_indices")
-        self._id_to_meta = get_namespace_data("faiss_meta")
+        # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+        # If you have a large number of vectors, you might want IVF or other indexes.
+        # For demonstration, we use a simple IndexFlatIP.
+        self._index = faiss.IndexFlatIP(self._dim)
+
+        # Keep a local store for metadata, IDs, etc.
+        # Maps → metadata (including your original ID).
+        self._id_to_meta = {}
+
+        # Attempt to load an existing index + metadata from disk
+        with self._storage_lock:
+            self._load_faiss_index()
 
-        if need_init:
-            if is_multiprocess:
-                # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
-                # If you have a large number of vectors, you might want IVF or other indexes.
-                # For demonstration, we use a simple IndexFlatIP.
-                self._index.value = faiss.IndexFlatIP(self._dim)
-                # Keep a local store for metadata, IDs, etc.
-                # Maps → metadata (including your original ID).
-                self._id_to_meta.update({})
-                # Attempt to load an existing index + metadata from disk
-                self._load_faiss_index()
-            else:
-                self._index = faiss.IndexFlatIP(self._dim)
-                self._id_to_meta.update({})
-                self._load_faiss_index()
 
     def _get_index(self):
-        """
-        Helper method to get the correct index object based on multiprocess mode.
-        Returns the actual index object that can be used for operations.
-        """
-        return self._index.value if is_multiprocess else self._index
+        """Return the Faiss index instance."""
+        return self._index
+
+    async def index_done_callback(self) -> None:
+        with self._storage_lock:
+            self._save_faiss_index()
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
@@ -134,34 +121,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Normalize embeddings for cosine similarity (in-place)
         faiss.normalize_L2(embeddings)
 
-        with self._storage_lock:
-            # Upsert logic:
-            # 1. Identify which vectors to remove if they exist
-            # 2. Remove them
-            # 3. Add the new vectors
-            existing_ids_to_remove = []
-            for meta, emb in zip(list_data, embeddings):
-                faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
-                if faiss_internal_id is not None:
-                    existing_ids_to_remove.append(faiss_internal_id)
+        # Upsert logic:
+        # 1. Identify which vectors to remove if they exist
+        # 2. Remove them
+        # 3.
Add the new vectors + existing_ids_to_remove = [] + for meta, emb in zip(list_data, embeddings): + faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"]) + if faiss_internal_id is not None: + existing_ids_to_remove.append(faiss_internal_id) - if existing_ids_to_remove: - self._remove_faiss_ids(existing_ids_to_remove) + if existing_ids_to_remove: + self._remove_faiss_ids(existing_ids_to_remove) - # Step 2: Add new vectors - index = self._get_index() - start_idx = index.ntotal - index.add(embeddings) + # Step 2: Add new vectors + index = self._get_index() + start_idx = index.ntotal + index.add(embeddings) - # Step 3: Store metadata + vector for each new ID - for i, meta in enumerate(list_data): - fid = start_idx + i - # Store the raw vector so we can rebuild if something is removed - meta["__vector__"] = embeddings[i].tolist() - self._id_to_meta.update({fid: meta}) + # Step 3: Store metadata + vector for each new ID + for i, meta in enumerate(list_data): + fid = start_idx + i + # Store the raw vector so we can rebuild if something is removed + meta["__vector__"] = embeddings[i].tolist() + self._id_to_meta.update({fid: meta}) - logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") - return [m["__id__"] for m in list_data] + logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") + return [m["__id__"] for m in list_data] async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """ @@ -177,57 +163,54 @@ class FaissVectorDBStorage(BaseVectorStorage): ) # Perform the similarity search - with self._storage_lock: - distances, indices = self._get_index().search(embedding, top_k) + distances, indices = self._get_index().search(embedding, top_k) - distances = distances[0] - indices = indices[0] + distances = distances[0] + indices = indices[0] - results = [] - for dist, idx in zip(distances, indices): - if idx == -1: - # Faiss returns -1 if no neighbor - continue + results = [] + for dist, idx in zip(distances, indices): + if idx == -1: + # Faiss returns -1 if no neighbor + continue - # Cosine similarity threshold - if dist < self.cosine_better_than_threshold: - continue + # Cosine similarity threshold + if dist < self.cosine_better_than_threshold: + continue - meta = self._id_to_meta.get(idx, {}) - results.append( - { - **meta, - "id": meta.get("__id__"), - "distance": float(dist), - "created_at": meta.get("__created_at__"), - } - ) + meta = self._id_to_meta.get(idx, {}) + results.append( + { + **meta, + "id": meta.get("__id__"), + "distance": float(dist), + "created_at": meta.get("__created_at__"), + } + ) - return results + return results @property def client_storage(self): # Return whatever structure LightRAG might need for debugging - with self._storage_lock: - return {"data": list(self._id_to_meta.values())} + return {"data": list(self._id_to_meta.values())} async def delete(self, ids: list[str]): """ Delete vectors for the provided custom IDs. 
""" logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") - with self._storage_lock: - to_remove = [] - for cid in ids: - fid = self._find_faiss_id_by_custom_id(cid) - if fid is not None: - to_remove.append(fid) + to_remove = [] + for cid in ids: + fid = self._find_faiss_id_by_custom_id(cid) + if fid is not None: + to_remove.append(fid) - if to_remove: - self._remove_faiss_ids(to_remove) - logger.debug( - f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" - ) + if to_remove: + self._remove_faiss_ids(to_remove) + logger.debug( + f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" + ) async def delete_entity(self, entity_name: str) -> None: entity_id = compute_mdhash_id(entity_name, prefix="ent-") @@ -239,23 +222,18 @@ class FaissVectorDBStorage(BaseVectorStorage): Delete relations for a given entity by scanning metadata. """ logger.debug(f"Searching relations for entity {entity_name}") - with self._storage_lock: - relations = [] - for fid, meta in self._id_to_meta.items(): - if ( - meta.get("src_id") == entity_name - or meta.get("tgt_id") == entity_name - ): - relations.append(fid) + relations = [] + for fid, meta in self._id_to_meta.items(): + if ( + meta.get("src_id") == entity_name + or meta.get("tgt_id") == entity_name + ): + relations.append(fid) - logger.debug(f"Found {len(relations)} relations for {entity_name}") - if relations: - self._remove_faiss_ids(relations) - logger.debug(f"Deleted {len(relations)} relations for {entity_name}") - - async def index_done_callback(self) -> None: - with self._storage_lock: - self._save_faiss_index() + logger.debug(f"Found {len(relations)} relations for {entity_name}") + if relations: + self._remove_faiss_ids(relations) + logger.debug(f"Deleted {len(relations)} relations for {entity_name}") # -------------------------------------------------------------------------------- # Internal helper methods @@ -265,11 +243,10 @@ class FaissVectorDBStorage(BaseVectorStorage): """ Return the Faiss internal ID for a given custom ID, or None if not found. """ - with self._storage_lock: - for fid, meta in self._id_to_meta.items(): - if meta.get("__id__") == custom_id: - return fid - return None + for fid, meta in self._id_to_meta.items(): + if meta.get("__id__") == custom_id: + return fid + return None def _remove_faiss_ids(self, fid_list): """ @@ -277,48 +254,42 @@ class FaissVectorDBStorage(BaseVectorStorage): Because IndexFlatIP doesn't support 'removals', we rebuild the index excluding those vectors. 
""" + keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] + + # Rebuild the index + vectors_to_keep = [] + new_id_to_meta = {} + for new_fid, old_fid in enumerate(keep_fids): + vec_meta = self._id_to_meta[old_fid] + vectors_to_keep.append(vec_meta["__vector__"]) # stored as list + new_id_to_meta[new_fid] = vec_meta + with self._storage_lock: - keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] - - # Rebuild the index - vectors_to_keep = [] - new_id_to_meta = {} - for new_fid, old_fid in enumerate(keep_fids): - vec_meta = self._id_to_meta[old_fid] - vectors_to_keep.append(vec_meta["__vector__"]) # stored as list - new_id_to_meta[new_fid] = vec_meta - - # Re-init index - new_index = faiss.IndexFlatIP(self._dim) + # Re-init index + self._index = faiss.IndexFlatIP(self._dim) if vectors_to_keep: arr = np.array(vectors_to_keep, dtype=np.float32) - new_index.add(arr) - if is_multiprocess: - self._index.value = new_index - else: - self._index = new_index + self._index.add(arr) + + self._id_to_meta = new_id_to_meta - self._id_to_meta.update(new_id_to_meta) def _save_faiss_index(self): """ Save the current Faiss index + metadata to disk so it can persist across runs. """ - with self._storage_lock: - faiss.write_index( - self._get_index(), - self._faiss_index_file, - ) + faiss.write_index(self._index, self._faiss_index_file) - # Save metadata dict to JSON. Convert all keys to strings for JSON storage. - # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } - # We'll keep the int -> dict, but JSON requires string keys. - serializable_dict = {} - for fid, meta in self._id_to_meta.items(): - serializable_dict[str(fid)] = meta + # Save metadata dict to JSON. Convert all keys to strings for JSON storage. + # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } + # We'll keep the int -> dict, but JSON requires string keys. 
+        serializable_dict = {}
+        for fid, meta in self._id_to_meta.items():
+            serializable_dict[str(fid)] = meta
+
+        with open(self._meta_file, "w", encoding="utf-8") as f:
+            json.dump(serializable_dict, f)
-            with open(self._meta_file, "w", encoding="utf-8") as f:
-                json.dump(serializable_dict, f)
 
     def _load_faiss_index(self):
         """
@@ -331,31 +302,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
         try:
             # Load the Faiss index
-            loaded_index = faiss.read_index(self._faiss_index_file)
-            if is_multiprocess:
-                self._index.value = loaded_index
-            else:
-                self._index = loaded_index
-
+            self._index = faiss.read_index(self._faiss_index_file)
             # Load metadata
             with open(self._meta_file, "r", encoding="utf-8") as f:
                 stored_dict = json.load(f)
 
             # Convert string keys back to int
-            self._id_to_meta.update({})
+            self._id_to_meta = {}
             for fid_str, meta in stored_dict.items():
                 fid = int(fid_str)
                 self._id_to_meta[fid] = meta
 
             logger.info(
-                f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
+                f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
             )
         except Exception as e:
             logger.error(f"Failed to load Faiss index or metadata: {e}")
             logger.warning("Starting with an empty Faiss index.")
-            new_index = faiss.IndexFlatIP(self._dim)
-            if is_multiprocess:
-                self._index.value = new_index
-            else:
-                self._index = new_index
-            self._id_to_meta.update({})
+            self._index = faiss.IndexFlatIP(self._dim)
+            self._id_to_meta = {}
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index 43dbcf97..b8fe573d 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -11,25 +11,19 @@ from lightrag.utils import (
 )
 import pipmaster as pm
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)
 
 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")
 
 from nano_vectordb import NanoVectorDB
-
+from threading import Lock as ThreadLock
 
 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize lock only for file operations
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = ThreadLock()
 
         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -45,32 +39,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
-        # check need_init must before get_namespace_object
-        need_init = try_initialize_namespace(self.namespace)
-        self._client = get_namespace_object(self.namespace)
-
-        if need_init:
-            if is_multiprocess:
-                self._client.value = NanoVectorDB(
-                    self.embedding_func.embedding_dim,
-                    storage_file=self._client_file_name,
-                )
-                logger.info(
-                    f"Initialized vector DB client for namespace {self.namespace}"
-                )
-            else:
-                self._client = NanoVectorDB(
-                    self.embedding_func.embedding_dim,
-                    storage_file=self._client_file_name,
-                )
-                logger.info(
-                    f"Initialized vector DB client for namespace {self.namespace}"
-                )
+        with self._storage_lock:
+            self._client = NanoVectorDB(
+                self.embedding_func.embedding_dim,
+                storage_file=self._client_file_name,
+            )
 
     def _get_client(self):
-        """Get the appropriate client instance based on multiprocess mode"""
-        if is_multiprocess:
-            return self._client.value
+        """Return the NanoVectorDB client instance."""
         return self._client
 
     async def upsert(self, data: dict[str, dict[str, Any]]) ->
None: @@ -101,8 +77,7 @@ class NanoVectorDBStorage(BaseVectorStorage): if len(embeddings) == len(list_data): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - with self._storage_lock: - results = self._get_client().upsert(datas=list_data) + results = self._get_client().upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. @@ -115,21 +90,20 @@ class NanoVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func([query]) embedding = embedding[0] - with self._storage_lock: - results = self._get_client().query( - query=embedding, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) - results = [ - { - **dp, - "id": dp["__id__"], - "distance": dp["__metrics__"], - "created_at": dp.get("__created_at__"), - } - for dp in results - ] + results = self._get_client().query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + { + **dp, + "id": dp["__id__"], + "distance": dp["__metrics__"], + "created_at": dp.get("__created_at__"), + } + for dp in results + ] return results @property @@ -143,8 +117,7 @@ class NanoVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - with self._storage_lock: - self._get_client().delete(ids) + self._get_client().delete(ids) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) @@ -158,37 +131,35 @@ class NanoVectorDBStorage(BaseVectorStorage): f"Attempting to delete entity {entity_name} with ID {entity_id}" ) - with self._storage_lock: - # Check if the entity exists - if self._get_client().get([entity_id]): - self._get_client().delete([entity_id]) - logger.debug(f"Successfully deleted entity {entity_name}") - else: - logger.debug(f"Entity {entity_name} not found in storage") + # Check if the entity exists + if self._get_client().get([entity_id]): + self._get_client().delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: try: - with self._storage_lock: - storage = getattr(self._get_client(), "_NanoVectorDB__storage") - relations = [ - dp - for dp in storage["data"] - if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name - ] - logger.debug( - f"Found {len(relations)} relations for entity {entity_name}" - ) - ids_to_delete = [relation["__id__"] for relation in relations] + storage = getattr(self._get_client(), "_NanoVectorDB__storage") + relations = [ + dp + for dp in storage["data"] + if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name + ] + logger.debug( + f"Found {len(relations)} relations for entity {entity_name}" + ) + ids_to_delete = [relation["__id__"] for relation in relations] - if ids_to_delete: - self._get_client().delete(ids_to_delete) - logger.debug( - f"Deleted {len(ids_to_delete)} relations for {entity_name}" - ) - else: - logger.debug(f"No relations found for entity {entity_name}") + if ids_to_delete: + self._get_client().delete(ids_to_delete) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" + ) + else: + logger.debug(f"No relations found for entity {entity_name}") except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 
d42db33a..1f14d5b0 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -6,12 +6,6 @@ import numpy as np
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)
 
 import pipmaster as pm
 
@@ -23,7 +17,7 @@ if not pm.is_installed("graspologic"):
 
 import networkx as nx
 from graspologic import embed
-
+from threading import Lock as ThreadLock
 
 @final
 @dataclass
@@ -78,38 +72,23 @@ class NetworkXStorage(BaseGraphStorage):
         self._graphml_xml_file = os.path.join(
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
-        self._storage_lock = get_storage_lock()
-
-        # check need_init must before get_namespace_object
-        need_init = try_initialize_namespace(self.namespace)
-        self._graph = get_namespace_object(self.namespace)
-
-        if need_init:
-            if is_multiprocess:
-                preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph.value = preloaded_graph or nx.Graph()
-                if preloaded_graph:
-                    logger.info(
-                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                    )
-            else:
-                preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph = preloaded_graph or nx.Graph()
-                if preloaded_graph:
-                    logger.info(
-                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                    )
+        self._storage_lock = ThreadLock()
+        with self._storage_lock:
+            preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+            if preloaded_graph is not None:
+                logger.info(
+                    f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
+                )
+            else:
                 logger.info("Created new empty graph")
-
+            self._graph = preloaded_graph or nx.Graph()
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }
 
     def _get_graph(self):
-        """Get the appropriate graph instance based on multiprocess mode"""
-        if is_multiprocess:
-            return self._graph.value
+        """Return the graph instance."""
        return self._graph
 
     async def index_done_callback(self) -> None:
@@ -117,54 +96,44 @@ class NetworkXStorage(BaseGraphStorage):
             NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)
 
     async def has_node(self, node_id: str) -> bool:
-        with self._storage_lock:
-            return self._get_graph().has_node(node_id)
+        return self._get_graph().has_node(node_id)
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        with self._storage_lock:
-            return self._get_graph().has_edge(source_node_id, target_node_id)
+        return self._get_graph().has_edge(source_node_id, target_node_id)
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        with self._storage_lock:
-            return self._get_graph().nodes.get(node_id)
+        return self._get_graph().nodes.get(node_id)
 
     async def node_degree(self, node_id: str) -> int:
-        with self._storage_lock:
-            return self._get_graph().degree(node_id)
+        return self._get_graph().degree(node_id)
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        with self._storage_lock:
-            return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
+        return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
 
     async def get_edge(
         self, source_node_id:
str, target_node_id: str ) -> dict[str, str] | None: - with self._storage_lock: - return self._get_graph().edges.get((source_node_id, target_node_id)) + return self._get_graph().edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - with self._storage_lock: - if self._get_graph().has_node(source_node_id): - return list(self._get_graph().edges(source_node_id)) - return None + if self._get_graph().has_node(source_node_id): + return list(self._get_graph().edges(source_node_id)) + return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - with self._storage_lock: - self._get_graph().add_node(node_id, **node_data) + self._get_graph().add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - with self._storage_lock: - self._get_graph().add_edge(source_node_id, target_node_id, **edge_data) + self._get_graph().add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: - with self._storage_lock: - if self._get_graph().has_node(node_id): - self._get_graph().remove_node(node_id) - logger.debug(f"Node {node_id} deleted from the graph.") - else: - logger.warning(f"Node {node_id} not found in the graph for deletion.") + if self._get_graph().has_node(node_id): + self._get_graph().remove_node(node_id) + logger.debug(f"Node {node_id} deleted from the graph.") + else: + logger.warning(f"Node {node_id} not found in the graph for deletion.") async def embed_nodes( self, algorithm: str @@ -175,13 +144,12 @@ class NetworkXStorage(BaseGraphStorage): # TODO: NOT USED async def _node2vec_embed(self): - with self._storage_lock: - graph = self._get_graph() - embeddings, nodes = embed.node2vec_embed( - graph, - **self.global_config["node2vec_params"], - ) - nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] + graph = self._get_graph() + embeddings, nodes = embed.node2vec_embed( + graph, + **self.global_config["node2vec_params"], + ) + nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids def remove_nodes(self, nodes: list[str]): @@ -190,11 +158,10 @@ class NetworkXStorage(BaseGraphStorage): Args: nodes: List of node IDs to be deleted """ - with self._storage_lock: - graph = self._get_graph() - for node in nodes: - if graph.has_node(node): - graph.remove_node(node) + graph = self._get_graph() + for node in nodes: + if graph.has_node(node): + graph.remove_node(node) def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges @@ -202,11 +169,10 @@ class NetworkXStorage(BaseGraphStorage): Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ - with self._storage_lock: - graph = self._get_graph() - for source, target in edges: - if graph.has_edge(source, target): - graph.remove_edge(source, target) + graph = self._get_graph() + for source, target in edges: + if graph.has_edge(source, target): + graph.remove_edge(source, target) async def get_all_labels(self) -> list[str]: """ @@ -214,10 +180,9 @@ class NetworkXStorage(BaseGraphStorage): Returns: [label1, label2, ...] 
# Alphabetically sorted label list """ - with self._storage_lock: - labels = set() - for node in self._get_graph().nodes(): - labels.add(str(node)) # Add node id as a label + labels = set() + for node in self._get_graph().nodes(): + labels.add(str(node)) # Add node id as a label # Return sorted list return sorted(list(labels)) @@ -239,88 +204,87 @@ class NetworkXStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() - with self._storage_lock: - graph = self._get_graph() + graph = self._get_graph() - # Handle special case for "*" label - if node_label == "*": - # For "*", return the entire graph including all nodes and edges - subgraph = ( - graph.copy() - ) # Create a copy to avoid modifying the original graph - else: - # Find nodes with matching node id (partial match) - nodes_to_explore = [] - for n, attr in graph.nodes(data=True): - if node_label in str(n): # Use partial matching - nodes_to_explore.append(n) + # Handle special case for "*" label + if node_label == "*": + # For "*", return the entire graph including all nodes and edges + subgraph = ( + graph.copy() + ) # Create a copy to avoid modifying the original graph + else: + # Find nodes with matching node id (partial match) + nodes_to_explore = [] + for n, attr in graph.nodes(data=True): + if node_label in str(n): # Use partial matching + nodes_to_explore.append(n) - if not nodes_to_explore: - logger.warning(f"No nodes found with label {node_label}") - return result + if not nodes_to_explore: + logger.warning(f"No nodes found with label {node_label}") + return result - # Get subgraph using ego_graph - subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) + # Get subgraph using ego_graph + subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) - # Check if number of nodes exceeds max_graph_nodes - max_graph_nodes = 500 - if len(subgraph.nodes()) > max_graph_nodes: - origin_nodes = len(subgraph.nodes()) - node_degrees = dict(subgraph.degree()) - top_nodes = sorted( - node_degrees.items(), key=lambda x: x[1], reverse=True - )[:max_graph_nodes] - top_node_ids = [node[0] for node in top_nodes] - # Create new subgraph with only top nodes - subgraph = subgraph.subgraph(top_node_ids) - logger.info( - f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" + # Check if number of nodes exceeds max_graph_nodes + max_graph_nodes = 500 + if len(subgraph.nodes()) > max_graph_nodes: + origin_nodes = len(subgraph.nodes()) + node_degrees = dict(subgraph.degree()) + top_nodes = sorted( + node_degrees.items(), key=lambda x: x[1], reverse=True + )[:max_graph_nodes] + top_node_ids = [node[0] for node in top_nodes] + # Create new subgraph with only top nodes + subgraph = subgraph.subgraph(top_node_ids) + logger.info( + f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" + ) + + # Add nodes to result + for node in subgraph.nodes(): + if str(node) in seen_nodes: + continue + + node_data = dict(subgraph.nodes[node]) + # Get entity_type as labels + labels = [] + if "entity_type" in node_data: + if isinstance(node_data["entity_type"], list): + labels.extend(node_data["entity_type"]) + else: + labels.append(node_data["entity_type"]) + + # Create node with properties + node_properties = {k: v for k, v in node_data.items()} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node), labels=[str(node)], properties=node_properties ) + ) + seen_nodes.add(str(node)) - # Add nodes to result - for node in subgraph.nodes(): - if str(node) in 
seen_nodes: - continue + # Add edges to result + for edge in subgraph.edges(): + source, target = edge + edge_id = f"{source}-{target}" + if edge_id in seen_edges: + continue - node_data = dict(subgraph.nodes[node]) - # Get entity_type as labels - labels = [] - if "entity_type" in node_data: - if isinstance(node_data["entity_type"], list): - labels.extend(node_data["entity_type"]) - else: - labels.append(node_data["entity_type"]) + edge_data = dict(subgraph.edges[edge]) - # Create node with properties - node_properties = {k: v for k, v in node_data.items()} - - result.nodes.append( - KnowledgeGraphNode( - id=str(node), labels=[str(node)], properties=node_properties - ) + # Create edge with complete information + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source), + target=str(target), + properties=edge_data, ) - seen_nodes.add(str(node)) - - # Add edges to result - for edge in subgraph.edges(): - source, target = edge - edge_id = f"{source}-{target}" - if edge_id in seen_edges: - continue - - edge_data = dict(subgraph.edges[edge]) - - # Create edge with complete information - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source), - target=str(target), - properties=edge_data, - ) - ) - seen_edges.add(edge_id) + ) + seen_edges.add(edge_id) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index c57771ba..681ef064 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -20,15 +20,12 @@ LockType = Union[ProcessLock, ThreadLock] _manager = None _initialized = None is_multiprocess = None +_global_lock: Optional[LockType] = None # shared data for storage across processes _shared_dicts: Optional[Dict[str, Any]] = None -_share_objects: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized -_global_lock: Optional[LockType] = None - - def initialize_share_data(workers: int = 1): """ Initialize shared storage data for single or multi-process mode. 
@@ -53,7 +50,6 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
 
@@ -72,7 +68,6 @@
         _global_lock = _manager.Lock()
         # Create shared dictionaries with manager
         _shared_dicts = _manager.dict()
-        _share_objects = _manager.dict()
         _init_flags = (
             _manager.dict()
         )  # Use shared dictionary to store initialization flags
@@ -83,7 +78,6 @@
         is_multiprocess = False
         _global_lock = ThreadLock()
         _shared_dicts = {}
-        _share_objects = {}
         _init_flags = {}
         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
 
@@ -99,11 +93,7 @@ def try_initialize_namespace(namespace: str) -> bool:
     global _init_flags, _manager
 
     if _init_flags is None:
-        direct_log(
-            f"Error: try to create nanmespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
+        raise ValueError("Tried to create namespace before Shared-Data is initialized")
 
     if namespace not in _init_flags:
         _init_flags[namespace] = True
@@ -113,43 +103,9 @@
     return False
 
 
-def _get_global_lock() -> LockType:
-    return _global_lock
-
-
 def get_storage_lock() -> LockType:
     """return storage lock for data consistency"""
-    return _get_global_lock()
-
-
-def get_scan_lock() -> LockType:
-    """return scan_progress lock for data consistency"""
-    return get_storage_lock()
-
-
-def get_namespace_object(namespace: str) -> Any:
-    """Get an object for specific namespace"""
-
-    if _share_objects is None:
-        direct_log(
-            f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
-
-    lock = _get_global_lock()
-    with lock:
-        if namespace not in _share_objects:
-            if namespace not in _share_objects:
-                if is_multiprocess:
-                    _share_objects[namespace] = _manager.Value("O", None)
-                else:
-                    _share_objects[namespace] = None
-            direct_log(
-                f"Created namespace: {namespace}(type={type(_share_objects[namespace])})"
-            )
-
-    return _share_objects[namespace]
+    return _global_lock
 
 
 def get_namespace_data(namespace: str) -> Dict[str, Any]:
@@ -161,7 +117,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
         )
         raise ValueError("Shared dictionaries not initialized")
 
-    lock = _get_global_lock()
+    lock = get_storage_lock()
     with lock:
         if namespace not in _shared_dicts:
             if is_multiprocess and _manager is not None:
@@ -175,11 +131,6 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
     return _shared_dicts[namespace]
 
 
-def get_scan_progress() -> Dict[str, Any]:
-    """get storage space for document scanning progress data"""
-    return get_namespace_data("scan_progress")
-
-
 def finalize_share_data():
     """
     Release shared resources and clean up.
@@ -195,7 +146,6 @@
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
 
@@ -216,8 +166,6 @@
         # Clear shared dictionaries first
         if _shared_dicts is not None:
             _shared_dicts.clear()
-        if _share_objects is not None:
-            _share_objects.clear()
         if _init_flags is not None:
             _init_flags.clear()
 
@@ -234,7 +182,6 @@
     _initialized = None
     is_multiprocess = None
     _shared_dicts = None
-    _share_objects = None
     _init_flags = None
     _global_lock = None
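Note (not part of the patch): the sketch below condenses the pattern this diff moves the Faiss backend to — each storage instance owns its own IndexFlatIP and metadata dict, guards file I/O with a plain threading.Lock, and persists explicitly instead of going through the shared multiprocess namespaces. It uses only standard faiss/numpy/json calls; the class and file names (DemoFaissStore, demo.index, demo_meta.json) are invented for illustration and do not exist in LightRAG.

import json
import os
from threading import Lock

import faiss  # pip install faiss-cpu
import numpy as np


class DemoFaissStore:
    """Illustrative stand-in: one IndexFlatIP per instance, file I/O guarded by a threading.Lock."""

    def __init__(self, dim: int, index_file: str = "demo.index", meta_file: str = "demo_meta.json"):
        self._dim = dim
        self._index_file = index_file
        self._meta_file = meta_file
        self._lock = Lock()  # serializes load/save within a single process
        self._index = faiss.IndexFlatIP(dim)  # inner product over L2-normalized vectors = cosine
        self._id_to_meta: dict[int, dict] = {}  # Faiss internal id -> metadata (incl. raw vector)
        with self._lock:
            self._load()

    def add(self, vectors: np.ndarray, metas: list[dict]) -> None:
        """Add an (n, dim) float array; metadata keeps each normalized vector for later rebuilds."""
        vectors = np.ascontiguousarray(vectors, dtype=np.float32)
        faiss.normalize_L2(vectors)  # in-place, so cosine similarity == inner product
        start = self._index.ntotal
        self._index.add(vectors)
        for i, meta in enumerate(metas):
            self._id_to_meta[start + i] = {**meta, "__vector__": vectors[i].tolist()}

    def remove(self, fids: list[int]) -> None:
        """IndexFlatIP cannot drop single rows, so rebuild the index from the kept vectors."""
        drop = set(fids)
        kept = [meta for fid, meta in sorted(self._id_to_meta.items()) if fid not in drop]
        self._index = faiss.IndexFlatIP(self._dim)
        self._id_to_meta = {new_fid: meta for new_fid, meta in enumerate(kept)}
        if kept:
            self._index.add(np.array([m["__vector__"] for m in kept], dtype=np.float32))

    def save(self) -> None:
        """Persist index + metadata; this mirrors what index_done_callback() triggers in the patch."""
        with self._lock:
            faiss.write_index(self._index, self._index_file)
            with open(self._meta_file, "w", encoding="utf-8") as f:
                json.dump({str(fid): meta for fid, meta in self._id_to_meta.items()}, f)

    def _load(self) -> None:
        """Reload a previous snapshot if both files exist; otherwise keep the empty index."""
        if not (os.path.exists(self._index_file) and os.path.exists(self._meta_file)):
            return
        self._index = faiss.read_index(self._index_file)
        with open(self._meta_file, "r", encoding="utf-8") as f:
            self._id_to_meta = {int(fid): meta for fid, meta in json.load(f).items()}

Usage follows the same lifecycle as the storage classes above: construct (which reloads any prior snapshot), add normalized vectors with their metadata, and call save() where LightRAG would invoke index_done_callback(). Deletions in the actual patch work the same way as remove() here: because a flat inner-product index has no per-row removal, the index is rebuilt from the vectors cached in the metadata.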