Refactor storage implementations to support both single and multi-process modes

• Add shared storage management module
• Select a process lock or a thread lock depending on the run mode
yangdx
2025-02-26 05:38:38 +08:00
parent 8050b0f91b
commit 2752a764ae
10 changed files with 608 additions and 623 deletions
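
The mode-dependent locking works roughly as follows (a condensed sketch of the new lightrag/kg/shared_storage.py added in this commit, not a separate API; the real module also caches per-namespace shared dicts and objects, see the new file below):

    from multiprocessing import Manager
    from threading import Lock as ThreadLock

    is_multiprocess = False   # set to True by main() when args.workers > 1
    manager = None            # created by initialize_manager() in multi-worker mode
    _global_lock = None

    def get_storage_lock():
        # Manager lock when several worker processes share state, thread lock otherwise
        global _global_lock
        if _global_lock is None:
            _global_lock = manager.Lock() if is_multiprocess else ThreadLock()
        return _global_lock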

View File

@@ -406,9 +406,6 @@ def create_app(args):
 def get_application():
     """Factory function for creating the FastAPI application"""
-    from .utils_api import initialize_manager
-    initialize_manager()
-
     # Get args from environment variable
     args_json = os.environ.get('LIGHTRAG_ARGS')
     if not args_json:

@@ -428,6 +425,12 @@ def main():
     # Save args to environment variable for child processes
     os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))

+    if args.workers > 1:
+        from lightrag.kg.shared_storage import initialize_manager
+        initialize_manager()
+        import lightrag.kg.shared_storage as shared_storage
+        shared_storage.is_multiprocess = True
+
     # Configure uvicorn logging
     logging.config.dictConfig({
         "version": 1,

View File

@@ -18,12 +18,10 @@ from pydantic import BaseModel, Field, field_validator
 from lightrag import LightRAG
 from lightrag.base import DocProcessingStatus, DocStatus
-from ..utils_api import (
-    get_api_key_dependency,
-    scan_progress,
-    update_scan_progress_if_not_scanning,
-    update_scan_progress,
-    reset_scan_progress,
+from ..utils_api import get_api_key_dependency
+from lightrag.kg.shared_storage import (
+    get_scan_progress,
+    get_scan_lock,
 )

@@ -378,23 +376,51 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
 async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     """Background task to scan and index documents"""
-    if not update_scan_progress_if_not_scanning():
-        ASCIIColors.info(
-            "Skip document scanning(another scanning is active)"
-        )
-        return
+    scan_progress = get_scan_progress()
+    scan_lock = get_scan_lock()
+
+    with scan_lock:
+        if scan_progress["is_scanning"]:
+            ASCIIColors.info(
+                "Skip document scanning(another scanning is active)"
+            )
+            return
+        scan_progress.update({
+            "is_scanning": True,
+            "current_file": "",
+            "indexed_count": 0,
+            "total_files": 0,
+            "progress": 0,
+        })

     try:
         new_files = doc_manager.scan_directory_for_new_files()
         total_files = len(new_files)
-        update_scan_progress("", total_files, 0)  # Initialize progress
+        scan_progress.update({
+            "current_file": "",
+            "total_files": total_files,
+            "indexed_count": 0,
+            "progress": 0,
+        })

         logging.info(f"Found {total_files} new files to index.")
         for idx, file_path in enumerate(new_files):
             try:
-                update_scan_progress(os.path.basename(file_path), total_files, idx)
+                progress = (idx / total_files * 100) if total_files > 0 else 0
+                scan_progress.update({
+                    "current_file": os.path.basename(file_path),
+                    "indexed_count": idx,
+                    "progress": progress,
+                })
+
                 await pipeline_index_file(rag, file_path)
-                update_scan_progress(os.path.basename(file_path), total_files, idx + 1)
+
+                progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
+                scan_progress.update({
+                    "current_file": os.path.basename(file_path),
+                    "indexed_count": idx + 1,
+                    "progress": progress,
+                })
             except Exception as e:
                 logging.error(f"Error indexing file {file_path}: {str(e)}")

@@ -402,7 +428,13 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     except Exception as e:
         logging.error(f"Error during scanning process: {str(e)}")
     finally:
-        reset_scan_progress()
+        scan_progress.update({
+            "is_scanning": False,
+            "current_file": "",
+            "indexed_count": 0,
+            "total_files": 0,
+            "progress": 0,
+        })

 def create_document_routes(

@@ -427,7 +459,7 @@ def create_document_routes(
         return {"status": "scanning_started"}

     @router.get("/scan-progress")
-    async def get_scan_progress():
+    async def get_scanning_progress():
         """
         Get the current progress of the document scanning process.

@@ -439,7 +471,7 @@ def create_document_routes(
             - total_files: Total number of files to process
             - progress: Percentage of completion
         """
-        return dict(scan_progress)
+        return dict(get_scan_progress())

     @router.post("/upload", dependencies=[Depends(optional_api_key)])
     async def upload_to_input_dir(
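
A client can poll the scan-progress endpoint above like this (base URL, port and router prefix are assumptions for illustration, not part of this diff):

    import time
    import requests

    BASE = "http://localhost:9621"  # adjust host/port and router prefix to your deployment

    while True:
        p = requests.get(f"{BASE}/documents/scan-progress", timeout=5).json()
        print(f"{p['indexed_count']}/{p['total_files']} files ({p['progress']:.0f}%)")
        if not p["is_scanning"]:
            break
        time.sleep(2)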

View File

@@ -6,7 +6,6 @@ import os
 import argparse
 from typing import Optional
 import sys
-from multiprocessing import Manager

 from ascii_colors import ASCIIColors
 from lightrag.api import __api_version__
 from fastapi import HTTPException, Security

@@ -17,66 +16,6 @@ from starlette.status import HTTP_403_FORBIDDEN
 # Load environment variables
 load_dotenv(override=True)

-# Global variables for manager and shared state
-manager = None
-scan_progress = None
-scan_lock = None
-
-
-def initialize_manager():
-    """Initialize manager and shared state for cross-process communication"""
-    global manager, scan_progress, scan_lock
-    if manager is None:
-        manager = Manager()
-        scan_progress = manager.dict({
-            "is_scanning": False,
-            "current_file": "",
-            "indexed_count": 0,
-            "total_files": 0,
-            "progress": 0,
-        })
-        scan_lock = manager.Lock()
-
-
-def update_scan_progress_if_not_scanning():
-    """
-    Atomically check if scanning is not in progress and update scan_progress if it's not.
-    Returns True if the update was successful, False if scanning was already in progress.
-    """
-    with scan_lock:
-        if not scan_progress["is_scanning"]:
-            scan_progress.update({
-                "is_scanning": True,
-                "current_file": "",
-                "indexed_count": 0,
-                "total_files": 0,
-                "progress": 0,
-            })
-            return True
-        return False
-
-
-def update_scan_progress(current_file: str, total_files: int, indexed_count: int):
-    """
-    Atomically update scan progress information.
-    """
-    progress = (indexed_count / total_files * 100) if total_files > 0 else 0
-    scan_progress.update({
-        "current_file": current_file,
-        "indexed_count": indexed_count,
-        "total_files": total_files,
-        "progress": progress,
-    })
-
-
-def reset_scan_progress():
-    """
-    Atomically reset scan progress to initial state.
-    """
-    scan_progress.update({
-        "is_scanning": False,
-        "current_file": "",
-        "indexed_count": 0,
-        "total_files": 0,
-        "progress": 0,
-    })
-

 class OllamaServerInfos:
     # Constants for emulated Ollama model information

View File

@@ -2,48 +2,21 @@ import os
 import time
 import asyncio
 from typing import Any, final
-import threading
 import json
 import numpy as np
 from dataclasses import dataclass
 import pipmaster as pm

-from lightrag.api.utils_api import manager as main_process_manager
-from lightrag.utils import (
-    logger,
-    compute_mdhash_id,
-)
-from lightrag.base import (
-    BaseVectorStorage,
-)
+from lightrag.utils import logger, compute_mdhash_id
+from lightrag.base import BaseVectorStorage
+from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess

 if not pm.is_installed("faiss"):
     pm.install("faiss")

 import faiss  # type: ignore

-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_indices = None
-_shared_meta = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_indices, _shared_meta
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_indices = _manager.dict()
-                _shared_meta = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
-

 @final
 @dataclass

@@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
-
-        # Ensure manager is initialized
-        _get_manager()
-
-        # Get or create namespace index and metadata
-        if self.namespace not in _shared_indices:
-            with _init_lock:
-                if self.namespace not in _shared_indices:
-                    try:
-                        # Create an empty Faiss index for inner product
-                        index = faiss.IndexFlatIP(self._dim)
-                        meta = {}
-
-                        # Load existing index if available
-                        if os.path.exists(self._faiss_index_file):
-                            try:
-                                index = faiss.read_index(self._faiss_index_file)
-                                with open(self._meta_file, "r", encoding="utf-8") as f:
-                                    stored_dict = json.load(f)
-                                # Convert string keys back to int
-                                meta = {int(k): v for k, v in stored_dict.items()}
-                                logger.info(
-                                    f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
-                                )
-                            except Exception as e:
-                                logger.error(f"Failed to load Faiss index or metadata: {e}")
-                                logger.warning("Starting with an empty Faiss index.")
-                                index = faiss.IndexFlatIP(self._dim)
-                                meta = {}
-
-                        _shared_indices[self.namespace] = index
-                        _shared_meta[self.namespace] = meta
-                    except Exception as e:
-                        logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
-                        raise RuntimeError(f"Faiss index initialization failed: {e}")
-
-        try:
-            self._index = _shared_indices[self.namespace]
-            self._id_to_meta = _shared_meta[self.namespace]
-        except Exception as e:
-            logger.error(f"Failed to access shared memory: {e}")
-            raise RuntimeError(f"Cannot access shared memory: {e}")
+        self._storage_lock = get_storage_lock()
+        self._index = get_namespace_object('faiss_indices')
+        self._id_to_meta = get_namespace_data('faiss_meta')
+
+        with self._storage_lock:
+            if is_multiprocess:
+                if self._index.value is None:
+                    # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+                    # If you have a large number of vectors, you might want IVF or other indexes.
+                    # For demonstration, we use a simple IndexFlatIP.
+                    self._index.value = faiss.IndexFlatIP(self._dim)
+            else:
+                if self._index is None:
+                    self._index = faiss.IndexFlatIP(self._dim)
+
+            # Keep a local store for metadata, IDs, etc.
+            # Maps <int faiss_id> → metadata (including your original ID).
+            self._id_to_meta.update({})
+
+            # Attempt to load an existing index + metadata from disk
+            self._load_faiss_index()

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """

@@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Normalize embeddings for cosine similarity (in-place)
         faiss.normalize_L2(embeddings)

-        # Upsert logic:
-        # 1. Identify which vectors to remove if they exist
-        # 2. Remove them
-        # 3. Add the new vectors
-        existing_ids_to_remove = []
-        for meta, emb in zip(list_data, embeddings):
-            faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
-            if faiss_internal_id is not None:
-                existing_ids_to_remove.append(faiss_internal_id)
-
-        if existing_ids_to_remove:
-            self._remove_faiss_ids(existing_ids_to_remove)
-
-        # Step 2: Add new vectors
-        start_idx = self._index.ntotal
-        self._index.add(embeddings)
-
-        # Step 3: Store metadata + vector for each new ID
-        for i, meta in enumerate(list_data):
-            fid = start_idx + i
-            # Store the raw vector so we can rebuild if something is removed
-            meta["__vector__"] = embeddings[i].tolist()
-            self._id_to_meta[fid] = meta
-
-        logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
-        return [m["__id__"] for m in list_data]
+        with self._storage_lock:
+            # Upsert logic:
+            # 1. Identify which vectors to remove if they exist
+            # 2. Remove them
+            # 3. Add the new vectors
+            existing_ids_to_remove = []
+            for meta, emb in zip(list_data, embeddings):
+                faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+                if faiss_internal_id is not None:
+                    existing_ids_to_remove.append(faiss_internal_id)
+
+            if existing_ids_to_remove:
+                self._remove_faiss_ids(existing_ids_to_remove)
+
+            # Step 2: Add new vectors
+            start_idx = (self._index.value if is_multiprocess else self._index).ntotal
+            if is_multiprocess:
+                self._index.value.add(embeddings)
+            else:
+                self._index.add(embeddings)
+
+            # Step 3: Store metadata + vector for each new ID
+            for i, meta in enumerate(list_data):
+                fid = start_idx + i
+                # Store the raw vector so we can rebuild if something is removed
+                meta["__vector__"] = embeddings[i].tolist()
+                self._id_to_meta.update({fid: meta})
+
+            logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
+            return [m["__id__"] for m in list_data]

     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """

@@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )

         # Perform the similarity search
-        distances, indices = self._index.search(embedding, top_k)
-        distances = distances[0]
-        indices = indices[0]
-
-        results = []
-        for dist, idx in zip(distances, indices):
-            if idx == -1:
-                # Faiss returns -1 if no neighbor
-                continue
-
-            # Cosine similarity threshold
-            if dist < self.cosine_better_than_threshold:
-                continue
-
-            meta = self._id_to_meta.get(idx, {})
-            results.append(
-                {
-                    **meta,
-                    "id": meta.get("__id__"),
-                    "distance": float(dist),
-                    "created_at": meta.get("__created_at__"),
-                }
-            )
-
-        return results
+        with self._storage_lock:
+            distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
+            distances = distances[0]
+            indices = indices[0]
+
+            results = []
+            for dist, idx in zip(distances, indices):
+                if idx == -1:
+                    # Faiss returns -1 if no neighbor
+                    continue
+
+                # Cosine similarity threshold
+                if dist < self.cosine_better_than_threshold:
+                    continue
+
+                meta = self._id_to_meta.get(idx, {})
+                results.append(
+                    {
+                        **meta,
+                        "id": meta.get("__id__"),
+                        "distance": float(dist),
+                        "created_at": meta.get("__created_at__"),
+                    }
+                )
+
+            return results

     @property
     def client_storage(self):
         # Return whatever structure LightRAG might need for debugging
-        return {"data": list(self._id_to_meta.values())}
+        with self._storage_lock:
+            return {"data": list(self._id_to_meta.values())}

     async def delete(self, ids: list[str]):
         """
         Delete vectors for the provided custom IDs.
         """
         logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
-        to_remove = []
-        for cid in ids:
-            fid = self._find_faiss_id_by_custom_id(cid)
-            if fid is not None:
-                to_remove.append(fid)
-
-        if to_remove:
-            self._remove_faiss_ids(to_remove)
-        logger.info(
-            f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
-        )
+        with self._storage_lock:
+            to_remove = []
+            for cid in ids:
+                fid = self._find_faiss_id_by_custom_id(cid)
+                if fid is not None:
+                    to_remove.append(fid)
+
+            if to_remove:
+                self._remove_faiss_ids(to_remove)
+            logger.debug(
+                f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
+            )

     async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")

@@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Delete relations for a given entity by scanning metadata.
         """
         logger.debug(f"Searching relations for entity {entity_name}")
-        relations = []
-        for fid, meta in self._id_to_meta.items():
-            if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
-                relations.append(fid)
-
-        logger.debug(f"Found {len(relations)} relations for {entity_name}")
-        if relations:
-            self._remove_faiss_ids(relations)
-            logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
+        with self._storage_lock:
+            relations = []
+            for fid, meta in self._id_to_meta.items():
+                if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
+                    relations.append(fid)
+
+            logger.debug(f"Found {len(relations)} relations for {entity_name}")
+            if relations:
+                self._remove_faiss_ids(relations)
+                logger.debug(f"Deleted {len(relations)} relations for {entity_name}")

     async def index_done_callback(self) -> None:
-        self._save_faiss_index()
+        with self._storage_lock:
+            self._save_faiss_index()

     # --------------------------------------------------------------------------------
     # Internal helper methods

@@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
         """
         Return the Faiss internal ID for a given custom ID, or None if not found.
         """
-        for fid, meta in self._id_to_meta.items():
-            if meta.get("__id__") == custom_id:
-                return fid
-        return None
+        with self._storage_lock:
+            for fid, meta in self._id_to_meta.items():
+                if meta.get("__id__") == custom_id:
+                    return fid
+            return None

     def _remove_faiss_ids(self, fid_list):
         """

@@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Because IndexFlatIP doesn't support 'removals',
         we rebuild the index excluding those vectors.
         """
-        keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
-
-        # Rebuild the index
-        vectors_to_keep = []
-        new_id_to_meta = {}
-        for new_fid, old_fid in enumerate(keep_fids):
-            vec_meta = self._id_to_meta[old_fid]
-            vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
-            new_id_to_meta[new_fid] = vec_meta
-
-        # Re-init index
-        self._index = faiss.IndexFlatIP(self._dim)
-        if vectors_to_keep:
-            arr = np.array(vectors_to_keep, dtype=np.float32)
-            self._index.add(arr)
-
-        self._id_to_meta = new_id_to_meta
+        with self._storage_lock:
+            keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
+
+            # Rebuild the index
+            vectors_to_keep = []
+            new_id_to_meta = {}
+            for new_fid, old_fid in enumerate(keep_fids):
+                vec_meta = self._id_to_meta[old_fid]
+                vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
+                new_id_to_meta[new_fid] = vec_meta
+
+            # Re-init index
+            new_index = faiss.IndexFlatIP(self._dim)
+            if vectors_to_keep:
+                arr = np.array(vectors_to_keep, dtype=np.float32)
+                new_index.add(arr)
+            if is_multiprocess:
+                self._index.value = new_index
+            else:
+                self._index = new_index
+
+            self._id_to_meta.update(new_id_to_meta)

     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
-        faiss.write_index(self._index, self._faiss_index_file)
-
-        # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
-        # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
-        # We'll keep the int -> dict, but JSON requires string keys.
-        serializable_dict = {}
-        for fid, meta in self._id_to_meta.items():
-            serializable_dict[str(fid)] = meta
-
-        with open(self._meta_file, "w", encoding="utf-8") as f:
-            json.dump(serializable_dict, f)
+        with self._storage_lock:
+            faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
+
+            # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
+            # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
+            # We'll keep the int -> dict, but JSON requires string keys.
+            serializable_dict = {}
+            for fid, meta in self._id_to_meta.items():
+                serializable_dict[str(fid)] = meta
+
+            with open(self._meta_file, "w", encoding="utf-8") as f:
+                json.dump(serializable_dict, f)

     def _load_faiss_index(self):
         """

@@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
         try:
             # Load the Faiss index
-            self._index = faiss.read_index(self._faiss_index_file)
+            loaded_index = faiss.read_index(self._faiss_index_file)
+            if is_multiprocess:
+                self._index.value = loaded_index
+            else:
+                self._index = loaded_index
+
             # Load metadata
             with open(self._meta_file, "r", encoding="utf-8") as f:
                 stored_dict = json.load(f)

             # Convert string keys back to int
-            self._id_to_meta = {}
+            self._id_to_meta.update({})
             for fid_str, meta in stored_dict.items():
                 fid = int(fid_str)
                 self._id_to_meta[fid] = meta

             logger.info(
-                f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
+                f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
             )
         except Exception as e:
             logger.error(f"Failed to load Faiss index or metadata: {e}")
             logger.warning("Starting with an empty Faiss index.")
-            self._index = faiss.IndexFlatIP(self._dim)
-            self._id_to_meta = {}
+            new_index = faiss.IndexFlatIP(self._dim)
+            if is_multiprocess:
+                self._index.value = new_index
+            else:
+                self._index = new_index
+            self._id_to_meta.update({})
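
The .value indirection used throughout this file comes from get_namespace_object() in the new shared_storage module: with multiple workers the namespace slot is a Manager().Value proxy, so the actual Faiss index lives behind .value; in single-process mode the slot holds the index directly. A condensed sketch of the access pattern (the helper name is illustrative only, not part of the diff):

    def _current_index(index_slot):
        # index_slot is whatever get_namespace_object('faiss_indices') returned
        return index_slot.value if is_multiprocess else index_slot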

View File

@@ -1,7 +1,6 @@
 from dataclasses import dataclass
 import os
 from typing import Any, Union, final
-import threading

 from lightrag.base import (
     DocProcessingStatus,

@@ -13,26 +12,7 @@ from lightrag.utils import (
     logger,
     write_json,
 )
-from lightrag.api.utils_api import manager as main_process_manager
-
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_doc_status_data = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_doc_status_data
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_doc_status_data = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
+from .shared_storage import get_namespace_data, get_storage_lock

 @final

@@ -43,45 +23,32 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-
-        # Ensure manager is initialized
-        _get_manager()
-
-        # Get or create namespace data
-        if self.namespace not in _shared_doc_status_data:
-            with _init_lock:
-                if self.namespace not in _shared_doc_status_data:
-                    try:
-                        initial_data = load_json(self._file_name) or {}
-                        _shared_doc_status_data[self.namespace] = initial_data
-                    except Exception as e:
-                        logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
-                        raise RuntimeError(f"Shared data initialization failed: {e}")
-
-        try:
-            self._data = _shared_doc_status_data[self.namespace]
-            logger.info(f"Loaded document status storage with {len(self._data)} records")
-        except Exception as e:
-            logger.error(f"Failed to access shared memory: {e}")
-            raise RuntimeError(f"Cannot access shared memory: {e}")
+        self._storage_lock = get_storage_lock()
+        self._data = get_namespace_data(self.namespace)
+        with self._storage_lock:
+            self._data.update(load_json(self._file_name) or {})
+        logger.info(f"Loaded document status storage with {len(self._data)} records")

     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(keys) - set(self._data.keys())
+        with self._storage_lock:
+            return set(keys) - set(self._data.keys())

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
-        for id in ids:
-            data = self._data.get(id, None)
-            if data:
-                result.append(data)
+        with self._storage_lock:
+            for id in ids:
+                data = self._data.get(id, None)
+                if data:
+                    result.append(data)
         return result

     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status.value: 0 for status in DocStatus}
-        for doc in self._data.values():
-            counts[doc["status"]] += 1
+        with self._storage_lock:
+            for doc in self._data.values():
+                counts[doc["status"]] += 1
         return counts

     async def get_docs_by_status(

@@ -89,39 +56,46 @@ class JsonDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
         result = {}
-        for k, v in self._data.items():
-            if v["status"] == status.value:
-                try:
-                    # Make a copy of the data to avoid modifying the original
-                    data = v.copy()
-                    # If content is missing, use content_summary as content
-                    if "content" not in data and "content_summary" in data:
-                        data["content"] = data["content_summary"]
-                    result[k] = DocProcessingStatus(**data)
-                except KeyError as e:
-                    logger.error(f"Missing required field for document {k}: {e}")
-                    continue
+        with self._storage_lock:
+            for k, v in self._data.items():
+                if v["status"] == status.value:
+                    try:
+                        # Make a copy of the data to avoid modifying the original
+                        data = v.copy()
+                        # If content is missing, use content_summary as content
+                        if "content" not in data and "content_summary" in data:
+                            data["content"] = data["content_summary"]
+                        result[k] = DocProcessingStatus(**data)
+                    except KeyError as e:
+                        logger.error(f"Missing required field for document {k}: {e}")
+                        continue
         return result

     async def index_done_callback(self) -> None:
-        write_json(self._data, self._file_name)
+        # File writes must hold the lock so concurrent processes cannot corrupt the file
+        with self._storage_lock:
+            write_json(self._data, self._file_name)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        self._data.update(data)
+        with self._storage_lock:
+            self._data.update(data)
         await self.index_done_callback()

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
-        return self._data.get(id)
+        with self._storage_lock:
+            return self._data.get(id)

     async def delete(self, doc_ids: list[str]):
-        for doc_id in doc_ids:
-            self._data.pop(doc_id, None)
+        with self._storage_lock:
+            for doc_id in doc_ids:
+                self._data.pop(doc_id, None)
         await self.index_done_callback()

     async def drop(self) -> None:
         """Drop the storage"""
-        self._data.clear()
+        with self._storage_lock:
+            self._data.clear()

View File

@@ -1,8 +1,6 @@
-import asyncio
 import os
 from dataclasses import dataclass
 from typing import Any, final
-import threading

 from lightrag.base import (
     BaseKVStorage,

@@ -12,26 +10,7 @@ from lightrag.utils import (
     logger,
     write_json,
 )
-from lightrag.api.utils_api import manager as main_process_manager
-
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_kv_data = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_kv_data
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_kv_data = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
+from .shared_storage import get_namespace_data, get_storage_lock

@@ -39,57 +18,49 @@ def _get_manager():
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
-        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._lock = asyncio.Lock()
-
-        # Ensure manager is initialized
-        _get_manager()
-
-        # Get or create namespace data
-        if self.namespace not in _shared_kv_data:
-            with _init_lock:
-                if self.namespace not in _shared_kv_data:
-                    try:
-                        initial_data = load_json(self._file_name) or {}
-                        _shared_kv_data[self.namespace] = initial_data
-                    except Exception as e:
-                        logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
-                        raise RuntimeError(f"Shared data initialization failed: {e}")
-
-        try:
-            self._data = _shared_kv_data[self.namespace]
-            logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
-        except Exception as e:
-            logger.error(f"Failed to access shared memory: {e}")
-            raise RuntimeError(f"Cannot access shared memory: {e}")
+        self._storage_lock = get_storage_lock()
+        self._data = get_namespace_data(self.namespace)
+        with self._storage_lock:
+            if not self._data:
+                self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+                self._data: dict[str, Any] = load_json(self._file_name) or {}
+                logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

     async def index_done_callback(self) -> None:
-        write_json(self._data, self._file_name)
+        # File writes must hold the lock so concurrent processes cannot corrupt the file
+        with self._storage_lock:
+            write_json(self._data, self._file_name)

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        return self._data.get(id)
+        with self._storage_lock:
+            return self._data.get(id)

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        return [
-            (
-                {k: v for k, v in self._data[id].items()}
-                if self._data.get(id, None)
-                else None
-            )
-            for id in ids
-        ]
+        with self._storage_lock:
+            return [
+                (
+                    {k: v for k, v in self._data[id].items()}
+                    if self._data.get(id, None)
+                    else None
+                )
+                for id in ids
+            ]

     async def filter_keys(self, keys: set[str]) -> set[str]:
-        return set(keys) - set(self._data.keys())
+        with self._storage_lock:
+            return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
+        with self._storage_lock:
+            left_data = {k: v for k, v in data.items() if k not in self._data}
+            self._data.update(left_data)

     async def delete(self, ids: list[str]) -> None:
-        for doc_id in ids:
-            self._data.pop(doc_id, None)
+        with self._storage_lock:
+            for doc_id in ids:
+                self._data.pop(doc_id, None)
         await self.index_done_callback()

View File

@@ -3,50 +3,29 @@ import os
 from typing import Any, final
 from dataclasses import dataclass
 import numpy as np
-import threading
 import time

 from lightrag.utils import (
     logger,
     compute_mdhash_id,
 )
-from lightrag.api.utils_api import manager as main_process_manager
 import pipmaster as pm
-from lightrag.base import (
-    BaseVectorStorage,
-)
+from lightrag.base import BaseVectorStorage
+from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess

 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")

 from nano_vectordb import NanoVectorDB

-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_vector_clients = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_vector_clients
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_vector_clients = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
-

 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize lock only for file operations
-        self._save_lock = asyncio.Lock()
+        self._storage_lock = get_storage_lock()
         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")

@@ -61,28 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]

-        # Ensure manager is initialized
-        _get_manager()
-
-        # Get or create namespace client
-        if self.namespace not in _shared_vector_clients:
-            with _init_lock:
-                if self.namespace not in _shared_vector_clients:
-                    try:
-                        client = NanoVectorDB(
-                            self.embedding_func.embedding_dim,
-                            storage_file=self._client_file_name
-                        )
-                        _shared_vector_clients[self.namespace] = client
-                    except Exception as e:
-                        logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}")
-                        raise RuntimeError(f"Vector DB client initialization failed: {e}")
-
-        try:
-            self._client = _shared_vector_clients[self.namespace]
-        except Exception as e:
-            logger.error(f"Failed to access shared memory: {e}")
-            raise RuntimeError(f"Cannot access shared memory: {e}")
+        self._client = get_namespace_object(self.namespace)
+
+        with self._storage_lock:
+            if is_multiprocess:
+                if self._client.value is None:
+                    self._client.value = NanoVectorDB(
+                        self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                    )
+            else:
+                if self._client is None:
+                    self._client = NanoVectorDB(
+                        self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                    )
+
+        logger.info(f"Initialized vector DB client for namespace {self.namespace}")
+
+    def _get_client(self):
+        """Get the appropriate client instance based on multiprocess mode"""
+        if is_multiprocess:
+            return self._client.value
+        return self._client

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")

@@ -104,6 +82,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             for i in range(0, len(contents), self._max_batch_size)
         ]

+        # Execute embedding outside of lock to avoid long lock times
         embedding_tasks = [self.embedding_func(batch) for batch in batches]
         embeddings_list = await asyncio.gather(*embedding_tasks)

@@ -111,7 +90,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if len(embeddings) == len(list_data):
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            results = self._client.upsert(datas=list_data)
+            with self._storage_lock:
+                client = self._get_client()
+                results = client.upsert(datas=list_data)
             return results
         else:
             # sometimes the embedding is not returned correctly. just log it.

@@ -120,27 +101,32 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )

     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+        # Execute embedding outside of lock to avoid long lock times
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
-        results = self._client.query(
-            query=embedding,
-            top_k=top_k,
-            better_than_threshold=self.cosine_better_than_threshold,
-        )
-        results = [
-            {
-                **dp,
-                "id": dp["__id__"],
-                "distance": dp["__metrics__"],
-                "created_at": dp.get("__created_at__"),
-            }
-            for dp in results
-        ]
+
+        with self._storage_lock:
+            client = self._get_client()
+            results = client.query(
+                query=embedding,
+                top_k=top_k,
+                better_than_threshold=self.cosine_better_than_threshold,
+            )
+            results = [
+                {
+                    **dp,
+                    "id": dp["__id__"],
+                    "distance": dp["__metrics__"],
+                    "created_at": dp.get("__created_at__"),
+                }
+                for dp in results
+            ]
         return results

     @property
     def client_storage(self):
-        return getattr(self._client, "_NanoVectorDB__storage")
+        client = self._get_client()
+        return getattr(client, "_NanoVectorDB__storage")

     async def delete(self, ids: list[str]):
         """Delete vectors with specified IDs

@@ -149,8 +135,10 @@ class NanoVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
-            self._client.delete(ids)
-            logger.info(
+            with self._storage_lock:
+                client = self._get_client()
+                client.delete(ids)
+            logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
             )
         except Exception as e:

@@ -162,35 +150,42 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.debug(
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )
-            # Check if the entity exists
-            if self._client.get([entity_id]):
-                await self.delete([entity_id])
-                logger.debug(f"Successfully deleted entity {entity_name}")
-            else:
-                logger.debug(f"Entity {entity_name} not found in storage")
+
+            with self._storage_lock:
+                client = self._get_client()
+                # Check if the entity exists
+                if client.get([entity_id]):
+                    client.delete([entity_id])
+                    logger.debug(f"Successfully deleted entity {entity_name}")
+                else:
+                    logger.debug(f"Entity {entity_name} not found in storage")
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
         try:
-            relations = [
-                dp
-                for dp in self.client_storage["data"]
-                if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
-            ]
-            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
-            ids_to_delete = [relation["__id__"] for relation in relations]
-
-            if ids_to_delete:
-                await self.delete(ids_to_delete)
-                logger.debug(
-                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
-                )
-            else:
-                logger.debug(f"No relations found for entity {entity_name}")
+            with self._storage_lock:
+                client = self._get_client()
+                storage = getattr(client, "_NanoVectorDB__storage")
+                relations = [
+                    dp
+                    for dp in storage["data"]
+                    if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
+                ]
+                logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
+                ids_to_delete = [relation["__id__"] for relation in relations]
+
+                if ids_to_delete:
+                    client.delete(ids_to_delete)
+                    logger.debug(
+                        f"Deleted {len(ids_to_delete)} relations for {entity_name}"
+                    )
+                else:
+                    logger.debug(f"No relations found for entity {entity_name}")
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")

     async def index_done_callback(self) -> None:
-        async with self._save_lock:
-            self._client.save()
+        with self._storage_lock:
+            client = self._get_client()
+            client.save()

View File

@@ -1,18 +1,13 @@
 import os
 from dataclasses import dataclass
 from typing import Any, final
-import threading

 import numpy as np

 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from lightrag.utils import (
-    logger,
-)
-from lightrag.api.utils_api import manager as main_process_manager
-from lightrag.base import (
-    BaseGraphStorage,
-)
+from lightrag.utils import logger
+from lightrag.base import BaseGraphStorage
+from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
 import pipmaster as pm

 if not pm.is_installed("networkx"):

@@ -24,25 +19,6 @@ if not pm.is_installed("graspologic"):
 import networkx as nx
 from graspologic import embed

-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_graphs = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_graphs
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_graphs = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
-

 @final
 @dataclass

@@ -97,76 +73,98 @@ class NetworkXStorage(BaseGraphStorage):
         self._graphml_xml_file = os.path.join(
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
-
-        # Ensure manager is initialized
-        _get_manager()
-
-        # Get or create namespace graph
-        if self.namespace not in _shared_graphs:
-            with _init_lock:
-                if self.namespace not in _shared_graphs:
-                    try:
-                        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                        if preloaded_graph is not None:
-                            logger.info(
-                                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                            )
-                        _shared_graphs[self.namespace] = preloaded_graph or nx.Graph()
-                    except Exception as e:
-                        logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}")
-                        raise RuntimeError(f"Graph initialization failed: {e}")
-
-        try:
-            self._graph = _shared_graphs[self.namespace]
-            self._node_embed_algorithms = {
-                "node2vec": self._node2vec_embed,
-            }
-        except Exception as e:
-            logger.error(f"Failed to access shared memory: {e}")
-            raise RuntimeError(f"Cannot access shared memory: {e}")
+        self._storage_lock = get_storage_lock()
+        self._graph = get_namespace_object(self.namespace)
+        with self._storage_lock:
+            if is_multiprocess:
+                if self._graph.value is None:
+                    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+                    self._graph.value = preloaded_graph or nx.Graph()
+                    logger.info(
+                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
+                    )
+            else:
+                if self._graph is None:
+                    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+                    self._graph = preloaded_graph or nx.Graph()
+                    logger.info(
+                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
+                    )
+
+        self._node_embed_algorithms = {
+            "node2vec": self._node2vec_embed,
+        }
+
+    def _get_graph(self):
+        """Get the appropriate graph instance based on multiprocess mode"""
+        if is_multiprocess:
+            return self._graph.value
+        return self._graph

     async def index_done_callback(self) -> None:
-        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
+        with self._storage_lock:
+            graph = self._get_graph()
+            NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)

     async def has_node(self, node_id: str) -> bool:
-        return self._graph.has_node(node_id)
+        with self._storage_lock:
+            graph = self._get_graph()
+            return graph.has_node(node_id)

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        return self._graph.has_edge(source_node_id, target_node_id)
+        with self._storage_lock:
+            graph = self._get_graph()
+            return graph.has_edge(source_node_id, target_node_id)

     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        return self._graph.nodes.get(node_id)
+        with self._storage_lock:
+            graph = self._get_graph()
+            return graph.nodes.get(node_id)

     async def node_degree(self, node_id: str) -> int:
-        return self._graph.degree(node_id)
+        with self._storage_lock:
+            graph = self._get_graph()
+            return graph.degree(node_id)

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        return self._graph.degree(src_id) + self._graph.degree(tgt_id)
+        with self._storage_lock:
+            graph = self._get_graph()
+            return graph.degree(src_id) + graph.degree(tgt_id)

     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        return self._graph.edges.get((source_node_id, target_node_id))
+        with self._storage_lock:
+            graph = self._get_graph()
+            return graph.edges.get((source_node_id, target_node_id))

     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        if self._graph.has_node(source_node_id):
-            return list(self._graph.edges(source_node_id))
-        return None
+        with self._storage_lock:
+            graph = self._get_graph()
+            if graph.has_node(source_node_id):
+                return list(graph.edges(source_node_id))
+            return None

     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        self._graph.add_node(node_id, **node_data)
+        with self._storage_lock:
+            graph = self._get_graph()
+            graph.add_node(node_id, **node_data)

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
-        self._graph.add_edge(source_node_id, target_node_id, **edge_data)
+        with self._storage_lock:
+            graph = self._get_graph()
+            graph.add_edge(source_node_id, target_node_id, **edge_data)

     async def delete_node(self, node_id: str) -> None:
-        if self._graph.has_node(node_id):
-            self._graph.remove_node(node_id)
-            logger.info(f"Node {node_id} deleted from the graph.")
-        else:
-            logger.warning(f"Node {node_id} not found in the graph for deletion.")
+        with self._storage_lock:
+            graph = self._get_graph()
+            if graph.has_node(node_id):
+                graph.remove_node(node_id)
+                logger.debug(f"Node {node_id} deleted from the graph.")
+            else:
+                logger.warning(f"Node {node_id} not found in the graph for deletion.")

     async def embed_nodes(
         self, algorithm: str

@@ -175,14 +173,15 @@ class NetworkXStorage(BaseGraphStorage):
             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
         return await self._node_embed_algorithms[algorithm]()

-    # @TODO: NOT USED
+    # TODO: NOT USED
     async def _node2vec_embed(self):
-        embeddings, nodes = embed.node2vec_embed(
-            self._graph,
-            **self.global_config["node2vec_params"],
-        )
-
-        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
+        with self._storage_lock:
+            graph = self._get_graph()
+            embeddings, nodes = embed.node2vec_embed(
+                graph,
+                **self.global_config["node2vec_params"],
+            )
+            nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids

     def remove_nodes(self, nodes: list[str]):

@@ -191,9 +190,11 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             nodes: List of node IDs to be deleted
         """
-        for node in nodes:
-            if self._graph.has_node(node):
-                self._graph.remove_node(node)
+        with self._storage_lock:
+            graph = self._get_graph()
+            for node in nodes:
+                if graph.has_node(node):
+                    graph.remove_node(node)

     def remove_edges(self, edges: list[tuple[str, str]]):
         """Delete multiple edges

@@ -201,9 +202,11 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
-        for source, target in edges:
-            if self._graph.has_edge(source, target):
-                self._graph.remove_edge(source, target)
+        with self._storage_lock:
+            graph = self._get_graph()
+            for source, target in edges:
+                if graph.has_edge(source, target):
+                    graph.remove_edge(source, target)

     async def get_all_labels(self) -> list[str]:
         """

@@ -211,9 +214,11 @@ class NetworkXStorage(BaseGraphStorage):
         Returns:
             [label1, label2, ...]  # Alphabetically sorted label list
         """
-        labels = set()
-        for node in self._graph.nodes():
-            labels.add(str(node))  # Add node id as a label
+        with self._storage_lock:
+            graph = self._get_graph()
+            labels = set()
+            for node in graph.nodes():
+                labels.add(str(node))  # Add node id as a label

         # Return sorted list
         return sorted(list(labels))

@@ -235,87 +240,86 @@ class NetworkXStorage(BaseGraphStorage):
         seen_nodes = set()
         seen_edges = set()

-        # Handle special case for "*" label
-        if node_label == "*":
-            # For "*", return the entire graph including all nodes and edges
-            subgraph = (
-                self._graph.copy()
-            )  # Create a copy to avoid modifying the original graph
-        else:
-            # Find nodes with matching node id (partial match)
-            nodes_to_explore = []
-            for n, attr in self._graph.nodes(data=True):
-                if node_label in str(n):  # Use partial matching
-                    nodes_to_explore.append(n)
-
-            if not nodes_to_explore:
-                logger.warning(f"No nodes found with label {node_label}")
-                return result
-
-            # Get subgraph using ego_graph
-            subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
-
-        # Check if number of nodes exceeds max_graph_nodes
-        max_graph_nodes = 500
-        if len(subgraph.nodes()) > max_graph_nodes:
-            origin_nodes = len(subgraph.nodes())
-            node_degrees = dict(subgraph.degree())
-            top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
-                :max_graph_nodes
-            ]
-            top_node_ids = [node[0] for node in top_nodes]
-            # Create new subgraph with only top nodes
-            subgraph = subgraph.subgraph(top_node_ids)
-            logger.info(
-                f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
-            )
-
-        # Add nodes to result
-        for node in subgraph.nodes():
-            if str(node) in seen_nodes:
-                continue
-
-            node_data = dict(subgraph.nodes[node])
-            # Get entity_type as labels
-            labels = []
-            if "entity_type" in node_data:
-                if isinstance(node_data["entity_type"], list):
-                    labels.extend(node_data["entity_type"])
-                else:
-                    labels.append(node_data["entity_type"])
-
-            # Create node with properties
-            node_properties = {k: v for k, v in node_data.items()}
-
-            result.nodes.append(
-                KnowledgeGraphNode(
-                    id=str(node), labels=[str(node)], properties=node_properties
-                )
-            )
-            seen_nodes.add(str(node))
-
-        # Add edges to result
-        for edge in subgraph.edges():
-            source, target = edge
-            edge_id = f"{source}-{target}"
-            if edge_id in seen_edges:
-                continue
-
-            edge_data = dict(subgraph.edges[edge])
-
-            # Create edge with complete information
-            result.edges.append(
-                KnowledgeGraphEdge(
-                    id=edge_id,
-                    type="DIRECTED",
-                    source=str(source),
-                    target=str(target),
-                    properties=edge_data,
-                )
-            )
-            seen_edges.add(edge_id)
-
-        # logger.info(result.edges)
+        with self._storage_lock:
+            graph = self._get_graph()
+
+            # Handle special case for "*" label
+            if node_label == "*":
+                # For "*", return the entire graph including all nodes and edges
+                subgraph = graph.copy()  # Create a copy to avoid modifying the original graph
+            else:
+                # Find nodes with matching node id (partial match)
+                nodes_to_explore = []
+                for n, attr in graph.nodes(data=True):
+                    if node_label in str(n):  # Use partial matching
+                        nodes_to_explore.append(n)
+
+                if not nodes_to_explore:
+                    logger.warning(f"No nodes found with label {node_label}")
+                    return result
+
+                # Get subgraph using ego_graph
+                subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
+
+            # Check if number of nodes exceeds max_graph_nodes
+            max_graph_nodes = 500
+            if len(subgraph.nodes()) > max_graph_nodes:
+                origin_nodes = len(subgraph.nodes())
+                node_degrees = dict(subgraph.degree())
+                top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
+                    :max_graph_nodes
+                ]
+                top_node_ids = [node[0] for node in top_nodes]
+                # Create new subgraph with only top nodes
+                subgraph = subgraph.subgraph(top_node_ids)
+                logger.info(
+                    f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
+                )
+
+            # Add nodes to result
+            for node in subgraph.nodes():
+                if str(node) in seen_nodes:
+                    continue
+
+                node_data = dict(subgraph.nodes[node])
+                # Get entity_type as labels
+                labels = []
+                if "entity_type" in node_data:
+                    if isinstance(node_data["entity_type"], list):
+                        labels.extend(node_data["entity_type"])
+                    else:
+                        labels.append(node_data["entity_type"])
+
+                # Create node with properties
+                node_properties = {k: v for k, v in node_data.items()}
+
+                result.nodes.append(
+                    KnowledgeGraphNode(
+                        id=str(node), labels=[str(node)], properties=node_properties
+                    )
+                )
+                seen_nodes.add(str(node))
+
+            # Add edges to result
+            for edge in subgraph.edges():
+                source, target = edge
+                edge_id = f"{source}-{target}"
+                if edge_id in seen_edges:
+                    continue
+
+                edge_data = dict(subgraph.edges[edge])
+
+                # Create edge with complete information
+                result.edges.append(
+                    KnowledgeGraphEdge(
+                        id=edge_id,
+                        type="DIRECTED",
+                        source=str(source),
+                        target=str(target),
+                        properties=edge_data,
+                    )
+                )
+                seen_edges.add(edge_id)

         logger.info(
             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
View File

@@ -0,0 +1,94 @@
from multiprocessing.synchronize import Lock as ProcessLock
from threading import Lock as ThreadLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union

# Type alias covering both lock flavors
LockType = Union[ProcessLock, ThreadLock]

# Module-level shared state
_shared_data: Optional[Dict[str, Any]] = None
_namespace_objects: Optional[Dict[str, Any]] = None
_global_lock: Optional[LockType] = None
is_multiprocess = False
manager = None


def initialize_manager():
    """Initialize the manager, only for multi-process mode (workers > 1)"""
    global manager
    if manager is None:
        manager = Manager()


def _get_global_lock() -> LockType:
    global _global_lock, is_multiprocess

    if _global_lock is None:
        if is_multiprocess:
            _global_lock = manager.Lock()
        else:
            _global_lock = ThreadLock()

    return _global_lock


def get_storage_lock() -> LockType:
    """Return the storage lock used to keep shared data consistent"""
    return _get_global_lock()


def get_scan_lock() -> LockType:
    """Return the lock guarding scan_progress data"""
    return get_storage_lock()


def get_shared_data() -> Dict[str, Any]:
    """
    Return the shared data container used by all storage types.
    The multiprocess-safe dict is created lazily, only when actually needed, for better performance.
    """
    global _shared_data, is_multiprocess

    if _shared_data is None:
        lock = _get_global_lock()
        with lock:
            if _shared_data is None:
                if is_multiprocess:
                    _shared_data = manager.dict()
                else:
                    _shared_data = {}

    return _shared_data


def get_namespace_object(namespace: str) -> Any:
    """Get an object slot for a specific namespace"""
    global _namespace_objects, is_multiprocess

    if _namespace_objects is None:
        lock = _get_global_lock()
        with lock:
            if _namespace_objects is None:
                _namespace_objects = {}

    if namespace not in _namespace_objects:
        lock = _get_global_lock()
        with lock:
            if namespace not in _namespace_objects:
                if is_multiprocess:
                    _namespace_objects[namespace] = manager.Value('O', None)
                else:
                    _namespace_objects[namespace] = None

    return _namespace_objects[namespace]


def get_namespace_data(namespace: str) -> Dict[str, Any]:
    """Get the storage space for a specific storage type (namespace)"""
    shared_data = get_shared_data()
    lock = _get_global_lock()

    if namespace not in shared_data:
        with lock:
            if namespace not in shared_data:
                shared_data[namespace] = {}

    return shared_data[namespace]


def get_scan_progress() -> Dict[str, Any]:
    """Get the storage space for document scanning progress data"""
    return get_namespace_data('scan_progress')
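
For reference, the way the storage backends in this commit consume this module boils down to the following pattern (a condensed sketch with an illustrative class name, not an additional file in the commit):

    from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock

    class ExampleKVBackend:
        def __init__(self, namespace: str):
            # One shared dict per namespace, plus the process- or thread-level lock
            self._lock = get_storage_lock()
            self._data = get_namespace_data(namespace)

        def put(self, key, value):
            # All reads and writes of the shared dict happen under the lock
            with self._lock:
                self._data.update({key: value})

        def get(self, key):
            with self._lock:
                return self._data.get(key)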

View File

@@ -266,13 +266,7 @@ class LightRAG:
     _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)

     def __post_init__(self):
-        # Initialize manager if needed
-        from lightrag.api.utils_api import manager, initialize_manager
-        if manager is None:
-            initialize_manager()
-            logger.info("Initialized manager for single process mode")
-
         os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
         set_logger(self.log_file_path, self.log_level)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")