From 2752a764ae39acb824cf519caaeabae689729a6b Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 05:38:38 +0800 Subject: [PATCH] Refactor storage implementations to support both single and multi-process modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add shared storage management module • Support process/thread lock based on mode --- lightrag/api/lightrag_server.py | 9 +- lightrag/api/routers/document_routes.py | 66 +++-- lightrag/api/utils_api.py | 61 ----- lightrag/kg/faiss_impl.py | 309 +++++++++++------------ lightrag/kg/json_doc_status_impl.py | 108 +++----- lightrag/kg/json_kv_impl.py | 91 +++---- lightrag/kg/nano_vector_db_impl.py | 165 ++++++------ lightrag/kg/networkx_impl.py | 320 ++++++++++++------------ lightrag/kg/shared_storage.py | 94 +++++++ lightrag/lightrag.py | 8 +- 10 files changed, 608 insertions(+), 623 deletions(-) create mode 100644 lightrag/kg/shared_storage.py diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 62cb24db..65227e97 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -406,9 +406,6 @@ def create_app(args): def get_application(): """Factory function for creating the FastAPI application""" - from .utils_api import initialize_manager - initialize_manager() - # Get args from environment variable args_json = os.environ.get('LIGHTRAG_ARGS') if not args_json: @@ -428,6 +425,12 @@ def main(): # Save args to environment variable for child processes os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args)) + if args.workers > 1: + from lightrag.kg.shared_storage import initialize_manager + initialize_manager() + import lightrag.kg.shared_storage as shared_storage + shared_storage.is_multiprocess = True + # Configure uvicorn logging logging.config.dictConfig({ "version": 1, diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index ea6bf29d..c084023d 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -18,12 +18,10 @@ from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus -from ..utils_api import ( - get_api_key_dependency, - scan_progress, - update_scan_progress_if_not_scanning, - update_scan_progress, - reset_scan_progress, +from ..utils_api import get_api_key_dependency +from lightrag.kg.shared_storage import ( + get_scan_progress, + get_scan_lock, ) @@ -378,23 +376,51 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): """Background task to scan and index documents""" - if not update_scan_progress_if_not_scanning(): - ASCIIColors.info( - "Skip document scanning(another scanning is active)" - ) - return + scan_progress = get_scan_progress() + scan_lock = get_scan_lock() + + with scan_lock: + if scan_progress["is_scanning"]: + ASCIIColors.info( + "Skip document scanning(another scanning is active)" + ) + return + scan_progress.update({ + "is_scanning": True, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) try: new_files = doc_manager.scan_directory_for_new_files() total_files = len(new_files) - update_scan_progress("", total_files, 0) # Initialize progress + scan_progress.update({ + "current_file": "", + "total_files": total_files, + "indexed_count": 0, + "progress": 0, + }) logging.info(f"Found {total_files} new files 
to index.") for idx, file_path in enumerate(new_files): try: - update_scan_progress(os.path.basename(file_path), total_files, idx) + progress = (idx / total_files * 100) if total_files > 0 else 0 + scan_progress.update({ + "current_file": os.path.basename(file_path), + "indexed_count": idx, + "progress": progress, + }) + await pipeline_index_file(rag, file_path) - update_scan_progress(os.path.basename(file_path), total_files, idx + 1) + + progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0 + scan_progress.update({ + "current_file": os.path.basename(file_path), + "indexed_count": idx + 1, + "progress": progress, + }) except Exception as e: logging.error(f"Error indexing file {file_path}: {str(e)}") @@ -402,7 +428,13 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): except Exception as e: logging.error(f"Error during scanning process: {str(e)}") finally: - reset_scan_progress() + scan_progress.update({ + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) def create_document_routes( @@ -427,7 +459,7 @@ def create_document_routes( return {"status": "scanning_started"} @router.get("/scan-progress") - async def get_scan_progress(): + async def get_scanning_progress(): """ Get the current progress of the document scanning process. @@ -439,7 +471,7 @@ def create_document_routes( - total_files: Total number of files to process - progress: Percentage of completion """ - return dict(scan_progress) + return dict(get_scan_progress()) @router.post("/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir( diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 2544276a..6b501e64 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -6,7 +6,6 @@ import os import argparse from typing import Optional import sys -from multiprocessing import Manager from ascii_colors import ASCIIColors from lightrag.api import __api_version__ from fastapi import HTTPException, Security @@ -17,66 +16,6 @@ from starlette.status import HTTP_403_FORBIDDEN # Load environment variables load_dotenv(override=True) -# Global variables for manager and shared state -manager = None -scan_progress = None -scan_lock = None - -def initialize_manager(): - """Initialize manager and shared state for cross-process communication""" - global manager, scan_progress, scan_lock - if manager is None: - manager = Manager() - scan_progress = manager.dict({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - scan_lock = manager.Lock() - -def update_scan_progress_if_not_scanning(): - """ - Atomically check if scanning is not in progress and update scan_progress if it's not. - Returns True if the update was successful, False if scanning was already in progress. - """ - with scan_lock: - if not scan_progress["is_scanning"]: - scan_progress.update({ - "is_scanning": True, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - return True - return False - -def update_scan_progress(current_file: str, total_files: int, indexed_count: int): - """ - Atomically update scan progress information. - """ - progress = (indexed_count / total_files * 100) if total_files > 0 else 0 - scan_progress.update({ - "current_file": current_file, - "indexed_count": indexed_count, - "total_files": total_files, - "progress": progress, - }) - -def reset_scan_progress(): - """ - Atomically reset scan progress to initial state. 
- """ - scan_progress.update({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - class OllamaServerInfos: # Constants for emulated Ollama model information diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 2e129472..8c9c52c4 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -2,48 +2,21 @@ import os import time import asyncio from typing import Any, final -import threading import json import numpy as np from dataclasses import dataclass import pipmaster as pm -from lightrag.api.utils_api import manager as main_process_manager -from lightrag.utils import ( - logger, - compute_mdhash_id, -) -from lightrag.base import ( - BaseVectorStorage, -) +from lightrag.utils import logger,compute_mdhash_id +from lightrag.base import BaseVectorStorage +from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess if not pm.is_installed("faiss"): pm.install("faiss") import faiss # type: ignore -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_indices = None -_shared_meta = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_indices, _shared_meta - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_indices = _manager.dict() - _shared_meta = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager - @final @dataclass @@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 
768) must match your embedding function self._dim = self.embedding_func.embedding_dim + self._storage_lock = get_storage_lock() - # Ensure manager is initialized - _get_manager() + self._index = get_namespace_object('faiss_indices') + self._id_to_meta = get_namespace_data('faiss_meta') - # Get or create namespace index and metadata - if self.namespace not in _shared_indices: - with _init_lock: - if self.namespace not in _shared_indices: - try: - # Create an empty Faiss index for inner product - index = faiss.IndexFlatIP(self._dim) - meta = {} - - # Load existing index if available - if os.path.exists(self._faiss_index_file): - try: - index = faiss.read_index(self._faiss_index_file) - with open(self._meta_file, "r", encoding="utf-8") as f: - stored_dict = json.load(f) - # Convert string keys back to int - meta = {int(k): v for k, v in stored_dict.items()} - logger.info( - f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}" - ) - except Exception as e: - logger.error(f"Failed to load Faiss index or metadata: {e}") - logger.warning("Starting with an empty Faiss index.") - index = faiss.IndexFlatIP(self._dim) - meta = {} - - _shared_indices[self.namespace] = index - _shared_meta[self.namespace] = meta - except Exception as e: - logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}") - raise RuntimeError(f"Faiss index initialization failed: {e}") - - try: - self._index = _shared_indices[self.namespace] - self._id_to_meta = _shared_meta[self.namespace] - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise RuntimeError(f"Cannot access shared memory: {e}") + with self._storage_lock: + if is_multiprocess: + if self._index.value is None: + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). + # If you have a large number of vectors, you might want IVF or other indexes. + # For demonstration, we use a simple IndexFlatIP. + self._index.value = faiss.IndexFlatIP(self._dim) + else: + if self._index is None: + self._index = faiss.IndexFlatIP(self._dim) + + # Keep a local store for metadata, IDs, etc. + # Maps → metadata (including your original ID). + self._id_to_meta.update({}) + + # Attempt to load an existing index + metadata from disk + self._load_faiss_index() + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ @@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage): # Normalize embeddings for cosine similarity (in-place) faiss.normalize_L2(embeddings) - # Upsert logic: - # 1. Identify which vectors to remove if they exist - # 2. Remove them - # 3. Add the new vectors - existing_ids_to_remove = [] - for meta, emb in zip(list_data, embeddings): - faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"]) - if faiss_internal_id is not None: - existing_ids_to_remove.append(faiss_internal_id) + with self._storage_lock: + # Upsert logic: + # 1. Identify which vectors to remove if they exist + # 2. Remove them + # 3. 
Add the new vectors + existing_ids_to_remove = [] + for meta, emb in zip(list_data, embeddings): + faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"]) + if faiss_internal_id is not None: + existing_ids_to_remove.append(faiss_internal_id) - if existing_ids_to_remove: - self._remove_faiss_ids(existing_ids_to_remove) + if existing_ids_to_remove: + self._remove_faiss_ids(existing_ids_to_remove) - # Step 2: Add new vectors - start_idx = self._index.ntotal - self._index.add(embeddings) + # Step 2: Add new vectors + start_idx = (self._index.value if is_multiprocess else self._index).ntotal + if is_multiprocess: + self._index.value.add(embeddings) + else: + self._index.add(embeddings) - # Step 3: Store metadata + vector for each new ID - for i, meta in enumerate(list_data): - fid = start_idx + i - # Store the raw vector so we can rebuild if something is removed - meta["__vector__"] = embeddings[i].tolist() - self._id_to_meta[fid] = meta + # Step 3: Store metadata + vector for each new ID + for i, meta in enumerate(list_data): + fid = start_idx + i + # Store the raw vector so we can rebuild if something is removed + meta["__vector__"] = embeddings[i].tolist() + self._id_to_meta.update({fid: meta}) - logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") - return [m["__id__"] for m in list_data] + logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") + return [m["__id__"] for m in list_data] async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """ @@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage): ) # Perform the similarity search - distances, indices = self._index.search(embedding, top_k) + with self._storage_lock: + distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k) - distances = distances[0] - indices = indices[0] + distances = distances[0] + indices = indices[0] - results = [] - for dist, idx in zip(distances, indices): - if idx == -1: - # Faiss returns -1 if no neighbor - continue + results = [] + for dist, idx in zip(distances, indices): + if idx == -1: + # Faiss returns -1 if no neighbor + continue - # Cosine similarity threshold - if dist < self.cosine_better_than_threshold: - continue + # Cosine similarity threshold + if dist < self.cosine_better_than_threshold: + continue - meta = self._id_to_meta.get(idx, {}) - results.append( - { - **meta, - "id": meta.get("__id__"), - "distance": float(dist), - "created_at": meta.get("__created_at__"), - } - ) + meta = self._id_to_meta.get(idx, {}) + results.append( + { + **meta, + "id": meta.get("__id__"), + "distance": float(dist), + "created_at": meta.get("__created_at__"), + } + ) - return results + return results @property def client_storage(self): # Return whatever structure LightRAG might need for debugging - return {"data": list(self._id_to_meta.values())} + with self._storage_lock: + return {"data": list(self._id_to_meta.values())} async def delete(self, ids: list[str]): """ Delete vectors for the provided custom IDs. 
""" logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") - to_remove = [] - for cid in ids: - fid = self._find_faiss_id_by_custom_id(cid) - if fid is not None: - to_remove.append(fid) + with self._storage_lock: + to_remove = [] + for cid in ids: + fid = self._find_faiss_id_by_custom_id(cid) + if fid is not None: + to_remove.append(fid) - if to_remove: - self._remove_faiss_ids(to_remove) - logger.info( - f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" - ) + if to_remove: + self._remove_faiss_ids(to_remove) + logger.debug( + f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" + ) async def delete_entity(self, entity_name: str) -> None: entity_id = compute_mdhash_id(entity_name, prefix="ent-") @@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage): Delete relations for a given entity by scanning metadata. """ logger.debug(f"Searching relations for entity {entity_name}") - relations = [] - for fid, meta in self._id_to_meta.items(): - if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: - relations.append(fid) + with self._storage_lock: + relations = [] + for fid, meta in self._id_to_meta.items(): + if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: + relations.append(fid) - logger.debug(f"Found {len(relations)} relations for {entity_name}") - if relations: - self._remove_faiss_ids(relations) - logger.debug(f"Deleted {len(relations)} relations for {entity_name}") + logger.debug(f"Found {len(relations)} relations for {entity_name}") + if relations: + self._remove_faiss_ids(relations) + logger.debug(f"Deleted {len(relations)} relations for {entity_name}") async def index_done_callback(self) -> None: - self._save_faiss_index() + with self._storage_lock: + self._save_faiss_index() # -------------------------------------------------------------------------------- # Internal helper methods @@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage): """ Return the Faiss internal ID for a given custom ID, or None if not found. """ - for fid, meta in self._id_to_meta.items(): - if meta.get("__id__") == custom_id: - return fid - return None + with self._storage_lock: + for fid, meta in self._id_to_meta.items(): + if meta.get("__id__") == custom_id: + return fid + return None def _remove_faiss_ids(self, fid_list): """ @@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage): Because IndexFlatIP doesn't support 'removals', we rebuild the index excluding those vectors. 
""" - keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] + with self._storage_lock: + keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] - # Rebuild the index - vectors_to_keep = [] - new_id_to_meta = {} - for new_fid, old_fid in enumerate(keep_fids): - vec_meta = self._id_to_meta[old_fid] - vectors_to_keep.append(vec_meta["__vector__"]) # stored as list - new_id_to_meta[new_fid] = vec_meta + # Rebuild the index + vectors_to_keep = [] + new_id_to_meta = {} + for new_fid, old_fid in enumerate(keep_fids): + vec_meta = self._id_to_meta[old_fid] + vectors_to_keep.append(vec_meta["__vector__"]) # stored as list + new_id_to_meta[new_fid] = vec_meta - # Re-init index - self._index = faiss.IndexFlatIP(self._dim) - if vectors_to_keep: - arr = np.array(vectors_to_keep, dtype=np.float32) - self._index.add(arr) + # Re-init index + new_index = faiss.IndexFlatIP(self._dim) + if vectors_to_keep: + arr = np.array(vectors_to_keep, dtype=np.float32) + new_index.add(arr) + if is_multiprocess: + self._index.value = new_index + else: + self._index = new_index - self._id_to_meta = new_id_to_meta + self._id_to_meta.update(new_id_to_meta) def _save_faiss_index(self): """ Save the current Faiss index + metadata to disk so it can persist across runs. """ - faiss.write_index(self._index, self._faiss_index_file) + with self._storage_lock: + faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file) - # Save metadata dict to JSON. Convert all keys to strings for JSON storage. - # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } - # We'll keep the int -> dict, but JSON requires string keys. - serializable_dict = {} - for fid, meta in self._id_to_meta.items(): - serializable_dict[str(fid)] = meta + # Save metadata dict to JSON. Convert all keys to strings for JSON storage. + # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } + # We'll keep the int -> dict, but JSON requires string keys. 
+ serializable_dict = {} + for fid, meta in self._id_to_meta.items(): + serializable_dict[str(fid)] = meta - with open(self._meta_file, "w", encoding="utf-8") as f: - json.dump(serializable_dict, f) + with open(self._meta_file, "w", encoding="utf-8") as f: + json.dump(serializable_dict, f) def _load_faiss_index(self): """ @@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage): try: # Load the Faiss index - self._index = faiss.read_index(self._faiss_index_file) + loaded_index = faiss.read_index(self._faiss_index_file) + if is_multiprocess: + self._index.value = loaded_index + else: + self._index = loaded_index + # Load metadata with open(self._meta_file, "r", encoding="utf-8") as f: stored_dict = json.load(f) # Convert string keys back to int - self._id_to_meta = {} + self._id_to_meta.update({}) for fid_str, meta in stored_dict.items(): fid = int(fid_str) self._id_to_meta[fid] = meta logger.info( - f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}" + f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}" ) except Exception as e: logger.error(f"Failed to load Faiss index or metadata: {e}") logger.warning("Starting with an empty Faiss index.") - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta = {} + new_index = faiss.IndexFlatIP(self._dim) + if is_multiprocess: + self._index.value = new_index + else: + self._index = new_index + self._id_to_meta.update({}) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index dd3a7b64..50451f95 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import os from typing import Any, Union, final -import threading from lightrag.base import ( DocProcessingStatus, @@ -13,26 +12,7 @@ from lightrag.utils import ( logger, write_json, ) -from lightrag.api.utils_api import manager as main_process_manager - -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_doc_status_data = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_doc_status_data - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_doc_status_data = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager +from .shared_storage import get_namespace_data, get_storage_lock @final @@ -43,45 +23,32 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - - # Ensure manager is initialized - _get_manager() - - # Get or create namespace data - if self.namespace not in _shared_doc_status_data: - with _init_lock: - if self.namespace not in _shared_doc_status_data: - try: - initial_data = load_json(self._file_name) or {} - _shared_doc_status_data[self.namespace] = initial_data - except Exception as e: - logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}") - raise RuntimeError(f"Shared data initialization failed: {e}") - - try: - self._data = _shared_doc_status_data[self.namespace] - logger.info(f"Loaded document status storage with {len(self._data)} records") - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise 
RuntimeError(f"Cannot access shared memory: {e}")
+        self._storage_lock = get_storage_lock()
+        self._data = get_namespace_data(self.namespace)
+        with self._storage_lock:
+            self._data.update(load_json(self._file_name) or {})
+        logger.info(f"Loaded document status storage with {len(self._data)} records")
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(keys) - set(self._data.keys())
+        with self._storage_lock:
+            return set(keys) - set(self._data.keys())
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
-        for id in ids:
-            data = self._data.get(id, None)
-            if data:
-                result.append(data)
+        with self._storage_lock:
+            for id in ids:
+                data = self._data.get(id, None)
+                if data:
+                    result.append(data)
         return result
 
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status.value: 0 for status in DocStatus}
-        for doc in self._data.values():
-            counts[doc["status"]] += 1
+        with self._storage_lock:
+            for doc in self._data.values():
+                counts[doc["status"]] += 1
         return counts
 
     async def get_docs_by_status(
@@ -89,39 +56,46 @@
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
         result = {}
-        for k, v in self._data.items():
-            if v["status"] == status.value:
-                try:
-                    # Make a copy of the data to avoid modifying the original
-                    data = v.copy()
-                    # If content is missing, use content_summary as content
-                    if "content" not in data and "content_summary" in data:
-                        data["content"] = data["content_summary"]
-                    result[k] = DocProcessingStatus(**data)
-                except KeyError as e:
-                    logger.error(f"Missing required field for document {k}: {e}")
-                    continue
+        with self._storage_lock:
+            for k, v in self._data.items():
+                if v["status"] == status.value:
+                    try:
+                        # Make a copy of the data to avoid modifying the original
+                        data = v.copy()
+                        # If content is missing, use content_summary as content
+                        if "content" not in data and "content_summary" in data:
+                            data["content"] = data["content_summary"]
+                        result[k] = DocProcessingStatus(**data)
+                    except KeyError as e:
+                        logger.error(f"Missing required field for document {k}: {e}")
+                        continue
         return result
 
     async def index_done_callback(self) -> None:
-        write_json(self._data, self._file_name)
+        # File writes must hold the lock to prevent concurrent writes from multiple processes corrupting the file
+        with self._storage_lock:
+            write_json(self._data, self._file_name)
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        self._data.update(data)
+        with self._storage_lock:
+            self._data.update(data)
         await self.index_done_callback()
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.get(id)
+        with self._storage_lock:
+            return self._data.get(id)
 
     async def delete(self, doc_ids: list[str]):
-        for doc_id in doc_ids:
-            self._data.pop(doc_id, None)
+        with self._storage_lock:
+            for doc_id in doc_ids:
+                self._data.pop(doc_id, None)
         await self.index_done_callback()
 
     async def drop(self) -> None:
         """Drop the storage"""
-        self._data.clear()
+        with self._storage_lock:
+            self._data.clear()
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index f5a8b488..a53ac8f0 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -1,8 +1,6 @@
-import asyncio
 import os
 from dataclasses import dataclass
 from typing import Any, final
-import threading
 
 from lightrag.base import (
     BaseKVStorage,
@@ -12,26 +10,7 @@ from lightrag.utils import (
     logger,
     write_json,
 )
-from lightrag.api.utils_api import manager as main_process_manager
-
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_kv_data = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_kv_data
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_kv_data = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
+from .shared_storage import get_namespace_data, get_storage_lock
 
 
 @final
@@ -39,57 +18,49 @@ def _get_manager():
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
-        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._lock = asyncio.Lock()
-
-        # Ensure manager is initialized
-        _get_manager()
-
-        # Get or create namespace data
-        if self.namespace not in _shared_kv_data:
-            with _init_lock:
-                if self.namespace not in _shared_kv_data:
-                    try:
-                        initial_data = load_json(self._file_name) or {}
-                        _shared_kv_data[self.namespace] = initial_data
-                    except Exception as e:
-                        logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
-                        raise RuntimeError(f"Shared data initialization failed: {e}")
-
-        try:
-            self._data = _shared_kv_data[self.namespace]
-            logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
-        except Exception as e:
-            logger.error(f"Failed to access shared memory: {e}")
-            raise RuntimeError(f"Cannot access shared memory: {e}")
+        self._storage_lock = get_storage_lock()
+        self._data = get_namespace_data(self.namespace)
+        with self._storage_lock:
+            if not self._data:
+                self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+                self._data: dict[str, Any] = load_json(self._file_name) or {}
+                logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
+
     async def index_done_callback(self) -> None:
-        write_json(self._data, self._file_name)
+        # File writes must hold the lock to prevent concurrent writes from multiple processes corrupting the file
+        with self._storage_lock:
+            write_json(self._data, self._file_name)
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        return self._data.get(id)
+        with self._storage_lock:
+            return self._data.get(id)
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return [
-            (
-                {k: v for k, v in self._data[id].items()}
-                if self._data.get(id, None)
-                else None
-            )
-            for id in ids
-        ]
+        with self._storage_lock:
+            return [
+                (
+                    {k: v for k, v in self._data[id].items()}
+                    if self._data.get(id, None)
+                    else None
+                )
+                for id in ids
+            ]
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
-        return set(keys) - set(self._data.keys())
+        with self._storage_lock:
+            return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
            return
-        left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
+        with self._storage_lock:
+            left_data = {k: v for k, v in data.items() if k not in self._data}
+            self._data.update(left_data)
 
     async def delete(self, ids: list[str]) -> None:
-        for doc_id in ids:
-            self._data.pop(doc_id, None)
+        with self._storage_lock:
+            for doc_id in ids:
+                self._data.pop(doc_id, None)
         await 
self.index_done_callback() diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 7c15142e..07f8d367 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -3,50 +3,29 @@ import os from typing import Any, final from dataclasses import dataclass import numpy as np -import threading import time from lightrag.utils import ( logger, compute_mdhash_id, ) -from lightrag.api.utils_api import manager as main_process_manager import pipmaster as pm -from lightrag.base import ( - BaseVectorStorage, -) +from lightrag.base import BaseVectorStorage +from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") from nano_vectordb import NanoVectorDB -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_vector_clients = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_vector_clients - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_vector_clients = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager - @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Initialize lock only for file operations - self._save_lock = asyncio.Lock() + self._storage_lock = get_storage_lock() + # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -61,28 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage): ) self._max_batch_size = self.global_config["embedding_batch_num"] - # Ensure manager is initialized - _get_manager() + self._client = get_namespace_object(self.namespace) - # Get or create namespace client - if self.namespace not in _shared_vector_clients: - with _init_lock: - if self.namespace not in _shared_vector_clients: - try: - client = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name - ) - _shared_vector_clients[self.namespace] = client - except Exception as e: - logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}") - raise RuntimeError(f"Vector DB client initialization failed: {e}") + with self._storage_lock: + if is_multiprocess: + if self._client.value is None: + self._client.value = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) + else: + if self._client is None: + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) - try: - self._client = _shared_vector_clients[self.namespace] - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise RuntimeError(f"Cannot access shared memory: {e}") + logger.info(f"Initialized vector DB client for namespace {self.namespace}") + + def _get_client(self): + """Get the appropriate client instance based on multiprocess mode""" + if is_multiprocess: + return self._client.value + return self._client async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") @@ -104,6 +82,7 @@ class NanoVectorDBStorage(BaseVectorStorage): for i in range(0, len(contents), self._max_batch_size) ] + # Execute 
embedding outside of lock to avoid long lock times embedding_tasks = [self.embedding_func(batch) for batch in batches] embeddings_list = await asyncio.gather(*embedding_tasks) @@ -111,7 +90,9 @@ class NanoVectorDBStorage(BaseVectorStorage): if len(embeddings) == len(list_data): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - results = self._client.upsert(datas=list_data) + with self._storage_lock: + client = self._get_client() + results = client.upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. @@ -120,27 +101,32 @@ class NanoVectorDBStorage(BaseVectorStorage): ) async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + # Execute embedding outside of lock to avoid long lock times embedding = await self.embedding_func([query]) embedding = embedding[0] - results = self._client.query( - query=embedding, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) - results = [ - { - **dp, - "id": dp["__id__"], - "distance": dp["__metrics__"], - "created_at": dp.get("__created_at__"), - } - for dp in results - ] + + with self._storage_lock: + client = self._get_client() + results = client.query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + { + **dp, + "id": dp["__id__"], + "distance": dp["__metrics__"], + "created_at": dp.get("__created_at__"), + } + for dp in results + ] return results @property def client_storage(self): - return getattr(self._client, "_NanoVectorDB__storage") + client = self._get_client() + return getattr(client, "_NanoVectorDB__storage") async def delete(self, ids: list[str]): """Delete vectors with specified IDs @@ -149,8 +135,10 @@ class NanoVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - self._client.delete(ids) - logger.info( + with self._storage_lock: + client = self._get_client() + client.delete(ids) + logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) except Exception as e: @@ -162,35 +150,42 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.debug( f"Attempting to delete entity {entity_name} with ID {entity_id}" ) - # Check if the entity exists - if self._client.get([entity_id]): - await self.delete([entity_id]) - logger.debug(f"Successfully deleted entity {entity_name}") - else: - logger.debug(f"Entity {entity_name} not found in storage") + + with self._storage_lock: + client = self._get_client() + # Check if the entity exists + if client.get([entity_id]): + client.delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: try: - relations = [ - dp - for dp in self.client_storage["data"] - if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name - ] - logger.debug(f"Found {len(relations)} relations for entity {entity_name}") - ids_to_delete = [relation["__id__"] for relation in relations] + with self._storage_lock: + client = self._get_client() + storage = getattr(client, "_NanoVectorDB__storage") + relations = [ + dp + for dp in storage["data"] + if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name + ] + logger.debug(f"Found {len(relations)} relations for entity {entity_name}") + ids_to_delete = [relation["__id__"] for relation in relations] - if ids_to_delete: - await 
self.delete(ids_to_delete) - logger.debug( - f"Deleted {len(ids_to_delete)} relations for {entity_name}" - ) - else: - logger.debug(f"No relations found for entity {entity_name}") + if ids_to_delete: + client.delete(ids_to_delete) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" + ) + else: + logger.debug(f"No relations found for entity {entity_name}") except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self) -> None: - async with self._save_lock: - self._client.save() + with self._storage_lock: + client = self._get_client() + client.save() diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index f3dd92dc..df07499b 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,18 +1,13 @@ import os from dataclasses import dataclass from typing import Any, final -import threading import numpy as np from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from lightrag.utils import ( - logger, -) -from lightrag.api.utils_api import manager as main_process_manager +from lightrag.utils import logger +from lightrag.base import BaseGraphStorage +from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess -from lightrag.base import ( - BaseGraphStorage, -) import pipmaster as pm if not pm.is_installed("networkx"): @@ -24,25 +19,6 @@ if not pm.is_installed("graspologic"): import networkx as nx from graspologic import embed -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_graphs = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_graphs - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_graphs = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager - @final @dataclass @@ -97,76 +73,98 @@ class NetworkXStorage(BaseGraphStorage): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) - - # Ensure manager is initialized - _get_manager() - - # Get or create namespace graph - if self.namespace not in _shared_graphs: - with _init_lock: - if self.namespace not in _shared_graphs: - try: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - _shared_graphs[self.namespace] = preloaded_graph or nx.Graph() - except Exception as e: - logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}") - raise RuntimeError(f"Graph initialization failed: {e}") - - try: - self._graph = _shared_graphs[self.namespace] - self._node_embed_algorithms = { + self._storage_lock = get_storage_lock() + self._graph = get_namespace_object(self.namespace) + with self._storage_lock: + if is_multiprocess: + if self._graph.value is None: + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + self._graph.value = preloaded_graph or nx.Graph() + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + else: + if self._graph is None: + preloaded_graph = 
NetworkXStorage.load_nx_graph(self._graphml_xml_file) + self._graph = preloaded_graph or nx.Graph() + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + + self._node_embed_algorithms = { "node2vec": self._node2vec_embed, - } - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise RuntimeError(f"Cannot access shared memory: {e}") + } + + def _get_graph(self): + """Get the appropriate graph instance based on multiprocess mode""" + if is_multiprocess: + return self._graph.value + return self._graph async def index_done_callback(self) -> None: - NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) + with self._storage_lock: + graph = self._get_graph() + NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: - return self._graph.has_node(node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - return self._graph.has_edge(source_node_id, target_node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> dict[str, str] | None: - return self._graph.nodes.get(node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: - return self._graph.degree(node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: - return self._graph.degree(src_id) + self._graph.degree(tgt_id) + with self._storage_lock: + graph = self._get_graph() + return graph.degree(src_id) + graph.degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - return self._graph.edges.get((source_node_id, target_node_id)) + with self._storage_lock: + graph = self._get_graph() + return graph.edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - if self._graph.has_node(source_node_id): - return list(self._graph.edges(source_node_id)) - return None + with self._storage_lock: + graph = self._get_graph() + if graph.has_node(source_node_id): + return list(graph.edges(source_node_id)) + return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - self._graph.add_node(node_id, **node_data) + with self._storage_lock: + graph = self._get_graph() + graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - self._graph.add_edge(source_node_id, target_node_id, **edge_data) + with self._storage_lock: + graph = self._get_graph() + graph.add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: - if self._graph.has_node(node_id): - self._graph.remove_node(node_id) - logger.info(f"Node {node_id} deleted from the graph.") - else: - logger.warning(f"Node {node_id} not found in the graph for deletion.") + with self._storage_lock: + graph = self._get_graph() + if graph.has_node(node_id): + graph.remove_node(node_id) + logger.debug(f"Node {node_id} deleted from the graph.") + else: + logger.warning(f"Node {node_id} not found in the graph for deletion.") async def 
embed_nodes( self, algorithm: str @@ -175,14 +173,15 @@ class NetworkXStorage(BaseGraphStorage): raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() - # @TODO: NOT USED + # TODO: NOT USED async def _node2vec_embed(self): - embeddings, nodes = embed.node2vec_embed( - self._graph, - **self.global_config["node2vec_params"], - ) - - nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] + with self._storage_lock: + graph = self._get_graph() + embeddings, nodes = embed.node2vec_embed( + graph, + **self.global_config["node2vec_params"], + ) + nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids def remove_nodes(self, nodes: list[str]): @@ -191,9 +190,11 @@ class NetworkXStorage(BaseGraphStorage): Args: nodes: List of node IDs to be deleted """ - for node in nodes: - if self._graph.has_node(node): - self._graph.remove_node(node) + with self._storage_lock: + graph = self._get_graph() + for node in nodes: + if graph.has_node(node): + graph.remove_node(node) def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges @@ -201,9 +202,11 @@ class NetworkXStorage(BaseGraphStorage): Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ - for source, target in edges: - if self._graph.has_edge(source, target): - self._graph.remove_edge(source, target) + with self._storage_lock: + graph = self._get_graph() + for source, target in edges: + if graph.has_edge(source, target): + graph.remove_edge(source, target) async def get_all_labels(self) -> list[str]: """ @@ -211,9 +214,11 @@ class NetworkXStorage(BaseGraphStorage): Returns: [label1, label2, ...] # Alphabetically sorted label list """ - labels = set() - for node in self._graph.nodes(): - labels.add(str(node)) # Add node id as a label + with self._storage_lock: + graph = self._get_graph() + labels = set() + for node in graph.nodes(): + labels.add(str(node)) # Add node id as a label # Return sorted list return sorted(list(labels)) @@ -235,87 +240,86 @@ class NetworkXStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() - # Handle special case for "*" label - if node_label == "*": - # For "*", return the entire graph including all nodes and edges - subgraph = ( - self._graph.copy() - ) # Create a copy to avoid modifying the original graph - else: - # Find nodes with matching node id (partial match) - nodes_to_explore = [] - for n, attr in self._graph.nodes(data=True): - if node_label in str(n): # Use partial matching - nodes_to_explore.append(n) + with self._storage_lock: + graph = self._get_graph() + + # Handle special case for "*" label + if node_label == "*": + # For "*", return the entire graph including all nodes and edges + subgraph = graph.copy() # Create a copy to avoid modifying the original graph + else: + # Find nodes with matching node id (partial match) + nodes_to_explore = [] + for n, attr in graph.nodes(data=True): + if node_label in str(n): # Use partial matching + nodes_to_explore.append(n) - if not nodes_to_explore: - logger.warning(f"No nodes found with label {node_label}") - return result + if not nodes_to_explore: + logger.warning(f"No nodes found with label {node_label}") + return result - # Get subgraph using ego_graph - subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth) + # Get subgraph using ego_graph + subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) - # Check if number of nodes exceeds max_graph_nodes - 
max_graph_nodes = 500 - if len(subgraph.nodes()) > max_graph_nodes: - origin_nodes = len(subgraph.nodes()) - node_degrees = dict(subgraph.degree()) - top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ - :max_graph_nodes - ] - top_node_ids = [node[0] for node in top_nodes] - # Create new subgraph with only top nodes - subgraph = subgraph.subgraph(top_node_ids) - logger.info( - f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" - ) - - # Add nodes to result - for node in subgraph.nodes(): - if str(node) in seen_nodes: - continue - - node_data = dict(subgraph.nodes[node]) - # Get entity_type as labels - labels = [] - if "entity_type" in node_data: - if isinstance(node_data["entity_type"], list): - labels.extend(node_data["entity_type"]) - else: - labels.append(node_data["entity_type"]) - - # Create node with properties - node_properties = {k: v for k, v in node_data.items()} - - result.nodes.append( - KnowledgeGraphNode( - id=str(node), labels=[str(node)], properties=node_properties + # Check if number of nodes exceeds max_graph_nodes + max_graph_nodes = 500 + if len(subgraph.nodes()) > max_graph_nodes: + origin_nodes = len(subgraph.nodes()) + node_degrees = dict(subgraph.degree()) + top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ + :max_graph_nodes + ] + top_node_ids = [node[0] for node in top_nodes] + # Create new subgraph with only top nodes + subgraph = subgraph.subgraph(top_node_ids) + logger.info( + f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" ) - ) - seen_nodes.add(str(node)) - # Add edges to result - for edge in subgraph.edges(): - source, target = edge - edge_id = f"{source}-{target}" - if edge_id in seen_edges: - continue + # Add nodes to result + for node in subgraph.nodes(): + if str(node) in seen_nodes: + continue - edge_data = dict(subgraph.edges[edge]) + node_data = dict(subgraph.nodes[node]) + # Get entity_type as labels + labels = [] + if "entity_type" in node_data: + if isinstance(node_data["entity_type"], list): + labels.extend(node_data["entity_type"]) + else: + labels.append(node_data["entity_type"]) - # Create edge with complete information - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source), - target=str(target), - properties=edge_data, + # Create node with properties + node_properties = {k: v for k, v in node_data.items()} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node), labels=[str(node)], properties=node_properties + ) ) - ) - seen_edges.add(edge_id) + seen_nodes.add(str(node)) - # logger.info(result.edges) + # Add edges to result + for edge in subgraph.edges(): + source, target = edge + edge_id = f"{source}-{target}" + if edge_id in seen_edges: + continue + + edge_data = dict(subgraph.edges[edge]) + + # Create edge with complete information + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source), + target=str(target), + properties=edge_data, + ) + ) + seen_edges.add(edge_id) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py new file mode 100644 index 00000000..9de3bb79 --- /dev/null +++ b/lightrag/kg/shared_storage.py @@ -0,0 +1,94 @@ +from multiprocessing.synchronize import Lock as ProcessLock +from threading import Lock as ThreadLock +from multiprocessing import Manager +from typing import Any, 
Dict, Optional, Union
+
+# Type definitions
+LockType = Union[ProcessLock, ThreadLock]
+
+# Global variables
+_shared_data: Optional[Dict[str, Any]] = None
+_namespace_objects: Optional[Dict[str, Any]] = None
+_global_lock: Optional[LockType] = None
+is_multiprocess = False
+manager = None
+
+def initialize_manager():
+    """Initialize the shared Manager; only needed in multi-process mode (workers > 1)"""
+    global manager
+    if manager is None:
+        manager = Manager()
+
+def _get_global_lock() -> LockType:
+    global _global_lock, is_multiprocess
+
+    if _global_lock is None:
+        if is_multiprocess:
+            _global_lock = manager.Lock()
+        else:
+            _global_lock = ThreadLock()
+
+    return _global_lock
+
+def get_storage_lock() -> LockType:
+    """Return the storage lock used to keep data consistent"""
+    return _get_global_lock()
+
+def get_scan_lock() -> LockType:
+    """Return the scan_progress lock used to keep data consistent"""
+    return get_storage_lock()
+
+def get_shared_data() -> Dict[str, Any]:
+    """
+    Return the shared data container used by all storage types.
+    A multiprocess-safe container is created only when needed, for better performance.
+    """
+    global _shared_data, is_multiprocess
+
+    if _shared_data is None:
+        lock = _get_global_lock()
+        with lock:
+            if _shared_data is None:
+                if is_multiprocess:
+                    _shared_data = manager.dict()
+                else:
+                    _shared_data = {}
+
+    return _shared_data
+
+def get_namespace_object(namespace: str) -> Any:
+    """Get the shared object for a specific namespace"""
+    global _namespace_objects, is_multiprocess
+
+    if _namespace_objects is None:
+        lock = _get_global_lock()
+        with lock:
+            if _namespace_objects is None:
+                _namespace_objects = {}
+
+    if namespace not in _namespace_objects:
+        lock = _get_global_lock()
+        with lock:
+            if namespace not in _namespace_objects:
+                if is_multiprocess:
+                    _namespace_objects[namespace] = manager.Value('O', None)
+                else:
+                    _namespace_objects[namespace] = None
+
+    return _namespace_objects[namespace]
+
+def get_namespace_data(namespace: str) -> Dict[str, Any]:
+    """Get the storage space for a specific storage type (namespace)"""
+    shared_data = get_shared_data()
+    lock = _get_global_lock()
+
+    if namespace not in shared_data:
+        with lock:
+            if namespace not in shared_data:
+                shared_data[namespace] = {}
+
+    return shared_data[namespace]
+
+def get_scan_progress() -> Dict[str, Any]:
+    """Get the storage space for document scanning progress data"""
+    return get_namespace_data('scan_progress')
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index c115b33a..d7da6017 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -266,13 +266,7 @@ class LightRAG:
 
     _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
 
-    def __post_init__(self):
-        # Initialize manager if needed
-        from lightrag.api.utils_api import manager, initialize_manager
-        if manager is None:
-            initialize_manager()
-            logger.info("Initialized manager for single process mode")
-
+    def __post_init__(self):
         os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
         set_logger(self.log_file_path, self.log_level)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
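
Note (not part of the patch): a minimal sketch of how a storage backend is expected to consume the new lightrag/kg/shared_storage.py helpers introduced above. The class name DemoKVStorage and the "demo" namespace are hypothetical; get_storage_lock(), get_namespace_data(), initialize_manager() and the is_multiprocess flag come from the module added in this patch. In multi-worker mode the server calls initialize_manager() and sets is_multiprocess = True in the parent process (see main() in lightrag_server.py) before any storage class is constructed.

    from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock

    class DemoKVStorage:
        """Hypothetical storage backend built on the shared_storage helpers."""

        def __init__(self):
            # One module-wide lock (threading.Lock or manager.Lock) guards all access.
            self._lock = get_storage_lock()
            # Plain dict in single-process mode, a manager.dict() when is_multiprocess is True.
            self._data = get_namespace_data("demo")

        def put(self, key: str, value: dict) -> None:
            with self._lock:
                self._data[key] = value

        def get(self, key: str, default=None):
            with self._lock:
                return self._data.get(key, default)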