Refactor storage implementations to support both single and multi-process modes

• Add shared storage management module
• Select a process or thread lock based on the run mode
yangdx
2025-02-26 05:38:38 +08:00
parent 8050b0f91b
commit 2752a764ae
10 changed files with 608 additions and 623 deletions
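At a high level, every refactored backend follows the same pattern: fetch a mode-aware lock and a per-namespace shared container from the new lightrag.kg.shared_storage module, then guard all reads and writes with that lock. A condensed sketch of the pattern (ExampleKVStorage is hypothetical; the real backends such as JsonKVStorage and FaissVectorDBStorage appear in the diffs below):

from dataclasses import dataclass
from typing import Any

from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock

@dataclass
class ExampleKVStorage:
    namespace: str

    def __post_init__(self):
        # One lock shared by all storage instances: a manager Lock when
        # is_multiprocess is set, a plain threading.Lock otherwise.
        self._storage_lock = get_storage_lock()
        # Per-namespace shared dict handed out by shared_storage.
        self._data = get_namespace_data(self.namespace)

    def get_by_id(self, id: str) -> Any:
        with self._storage_lock:
            return self._data.get(id)

    def upsert(self, data: dict[str, Any]) -> None:
        with self._storage_lock:
            self._data.update(data)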

View File

@@ -406,9 +406,6 @@ def create_app(args):
def get_application():
"""Factory function for creating the FastAPI application"""
from .utils_api import initialize_manager
initialize_manager()
# Get args from environment variable
args_json = os.environ.get('LIGHTRAG_ARGS')
if not args_json:
@@ -428,6 +425,12 @@ def main():
# Save args to environment variable for child processes
os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))
if args.workers > 1:
from lightrag.kg.shared_storage import initialize_manager
initialize_manager()
import lightrag.kg.shared_storage as shared_storage
shared_storage.is_multiprocess = True
# Configure uvicorn logging
logging.config.dictConfig({
"version": 1,

View File

@@ -18,12 +18,10 @@ from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import (
get_api_key_dependency,
scan_progress,
update_scan_progress_if_not_scanning,
update_scan_progress,
reset_scan_progress,
from ..utils_api import get_api_key_dependency
from lightrag.kg.shared_storage import (
get_scan_progress,
get_scan_lock,
)
@@ -378,23 +376,51 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
"""Background task to scan and index documents"""
if not update_scan_progress_if_not_scanning():
ASCIIColors.info(
"Skip document scanning(another scanning is active)"
)
return
scan_progress = get_scan_progress()
scan_lock = get_scan_lock()
with scan_lock:
if scan_progress["is_scanning"]:
ASCIIColors.info(
"Skip document scanning(another scanning is active)"
)
return
scan_progress.update({
"is_scanning": True,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
})
try:
new_files = doc_manager.scan_directory_for_new_files()
total_files = len(new_files)
update_scan_progress("", total_files, 0) # Initialize progress
scan_progress.update({
"current_file": "",
"total_files": total_files,
"indexed_count": 0,
"progress": 0,
})
logging.info(f"Found {total_files} new files to index.")
for idx, file_path in enumerate(new_files):
try:
update_scan_progress(os.path.basename(file_path), total_files, idx)
progress = (idx / total_files * 100) if total_files > 0 else 0
scan_progress.update({
"current_file": os.path.basename(file_path),
"indexed_count": idx,
"progress": progress,
})
await pipeline_index_file(rag, file_path)
update_scan_progress(os.path.basename(file_path), total_files, idx + 1)
progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
scan_progress.update({
"current_file": os.path.basename(file_path),
"indexed_count": idx + 1,
"progress": progress,
})
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
@@ -402,7 +428,13 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
reset_scan_progress()
scan_progress.update({
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
})
def create_document_routes(
@@ -427,7 +459,7 @@ def create_document_routes(
return {"status": "scanning_started"}
@router.get("/scan-progress")
async def get_scan_progress():
async def get_scanning_progress():
"""
Get the current progress of the document scanning process.
@@ -439,7 +471,7 @@ def create_document_routes(
- total_files: Total number of files to process
- progress: Percentage of completion
"""
return dict(scan_progress)
return dict(get_scan_progress())
@router.post("/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(

View File

@@ -6,7 +6,6 @@ import os
import argparse
from typing import Optional
import sys
from multiprocessing import Manager
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security
@@ -17,66 +16,6 @@ from starlette.status import HTTP_403_FORBIDDEN
# Load environment variables
load_dotenv(override=True)
# Global variables for manager and shared state
manager = None
scan_progress = None
scan_lock = None
def initialize_manager():
"""Initialize manager and shared state for cross-process communication"""
global manager, scan_progress, scan_lock
if manager is None:
manager = Manager()
scan_progress = manager.dict({
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
})
scan_lock = manager.Lock()
def update_scan_progress_if_not_scanning():
"""
Atomically check if scanning is not in progress and update scan_progress if it's not.
Returns True if the update was successful, False if scanning was already in progress.
"""
with scan_lock:
if not scan_progress["is_scanning"]:
scan_progress.update({
"is_scanning": True,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
})
return True
return False
def update_scan_progress(current_file: str, total_files: int, indexed_count: int):
"""
Atomically update scan progress information.
"""
progress = (indexed_count / total_files * 100) if total_files > 0 else 0
scan_progress.update({
"current_file": current_file,
"indexed_count": indexed_count,
"total_files": total_files,
"progress": progress,
})
def reset_scan_progress():
"""
Atomically reset scan progress to initial state.
"""
scan_progress.update({
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
})
class OllamaServerInfos:
# Constants for emulated Ollama model information

View File

@@ -2,48 +2,21 @@ import os
import time
import asyncio
from typing import Any, final
import threading
import json
import numpy as np
from dataclasses import dataclass
import pipmaster as pm
from lightrag.api.utils_api import manager as main_process_manager
from lightrag.utils import (
logger,
compute_mdhash_id,
)
from lightrag.base import (
BaseVectorStorage,
)
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage
from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
if not pm.is_installed("faiss"):
pm.install("faiss")
import faiss # type: ignore
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_indices = None
_shared_meta = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_indices, _shared_meta
with _init_lock:
if _manager is None:
try:
_manager = main_process_manager
_shared_indices = _manager.dict()
_shared_meta = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
@final
@dataclass
@@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
# Embedding dimension (e.g. 768) must match your embedding function
self._dim = self.embedding_func.embedding_dim
self._storage_lock = get_storage_lock()
# Ensure manager is initialized
_get_manager()
self._index = get_namespace_object('faiss_indices')
self._id_to_meta = get_namespace_data('faiss_meta')
# Get or create namespace index and metadata
if self.namespace not in _shared_indices:
with _init_lock:
if self.namespace not in _shared_indices:
try:
# Create an empty Faiss index for inner product
index = faiss.IndexFlatIP(self._dim)
meta = {}
# Load existing index if available
if os.path.exists(self._faiss_index_file):
try:
index = faiss.read_index(self._faiss_index_file)
with open(self._meta_file, "r", encoding="utf-8") as f:
stored_dict = json.load(f)
# Convert string keys back to int
meta = {int(k): v for k, v in stored_dict.items()}
logger.info(
f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
)
except Exception as e:
logger.error(f"Failed to load Faiss index or metadata: {e}")
logger.warning("Starting with an empty Faiss index.")
index = faiss.IndexFlatIP(self._dim)
meta = {}
_shared_indices[self.namespace] = index
_shared_meta[self.namespace] = meta
except Exception as e:
logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
raise RuntimeError(f"Faiss index initialization failed: {e}")
try:
self._index = _shared_indices[self.namespace]
self._id_to_meta = _shared_meta[self.namespace]
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
with self._storage_lock:
if is_multiprocess:
if self._index.value is None:
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
# If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP.
self._index.value = faiss.IndexFlatIP(self._dim)
else:
if self._index is None:
self._index = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc.
# Maps <int faiss_id> → metadata (including your original ID).
self._id_to_meta.update({})
# Attempt to load an existing index + metadata from disk
self._load_faiss_index()
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
@@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Normalize embeddings for cosine similarity (in-place)
faiss.normalize_L2(embeddings)
# Upsert logic:
# 1. Identify which vectors to remove if they exist
# 2. Remove them
# 3. Add the new vectors
existing_ids_to_remove = []
for meta, emb in zip(list_data, embeddings):
faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
if faiss_internal_id is not None:
existing_ids_to_remove.append(faiss_internal_id)
with self._storage_lock:
# Upsert logic:
# 1. Identify which vectors to remove if they exist
# 2. Remove them
# 3. Add the new vectors
existing_ids_to_remove = []
for meta, emb in zip(list_data, embeddings):
faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
if faiss_internal_id is not None:
existing_ids_to_remove.append(faiss_internal_id)
if existing_ids_to_remove:
self._remove_faiss_ids(existing_ids_to_remove)
if existing_ids_to_remove:
self._remove_faiss_ids(existing_ids_to_remove)
# Step 2: Add new vectors
start_idx = self._index.ntotal
self._index.add(embeddings)
# Step 2: Add new vectors
start_idx = (self._index.value if is_multiprocess else self._index).ntotal
if is_multiprocess:
self._index.value.add(embeddings)
else:
self._index.add(embeddings)
# Step 3: Store metadata + vector for each new ID
for i, meta in enumerate(list_data):
fid = start_idx + i
# Store the raw vector so we can rebuild if something is removed
meta["__vector__"] = embeddings[i].tolist()
self._id_to_meta[fid] = meta
# Step 3: Store metadata + vector for each new ID
for i, meta in enumerate(list_data):
fid = start_idx + i
# Store the raw vector so we can rebuild if something is removed
meta["__vector__"] = embeddings[i].tolist()
self._id_to_meta.update({fid: meta})
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data]
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data]
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""
@@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage):
)
# Perform the similarity search
distances, indices = self._index.search(embedding, top_k)
with self._storage_lock:
distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
distances = distances[0]
indices = indices[0]
distances = distances[0]
indices = indices[0]
results = []
for dist, idx in zip(distances, indices):
if idx == -1:
# Faiss returns -1 if no neighbor
continue
results = []
for dist, idx in zip(distances, indices):
if idx == -1:
# Faiss returns -1 if no neighbor
continue
# Cosine similarity threshold
if dist < self.cosine_better_than_threshold:
continue
# Cosine similarity threshold
if dist < self.cosine_better_than_threshold:
continue
meta = self._id_to_meta.get(idx, {})
results.append(
{
**meta,
"id": meta.get("__id__"),
"distance": float(dist),
"created_at": meta.get("__created_at__"),
}
)
meta = self._id_to_meta.get(idx, {})
results.append(
{
**meta,
"id": meta.get("__id__"),
"distance": float(dist),
"created_at": meta.get("__created_at__"),
}
)
return results
return results
@property
def client_storage(self):
# Return whatever structure LightRAG might need for debugging
return {"data": list(self._id_to_meta.values())}
with self._storage_lock:
return {"data": list(self._id_to_meta.values())}
async def delete(self, ids: list[str]):
"""
Delete vectors for the provided custom IDs.
"""
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
to_remove = []
for cid in ids:
fid = self._find_faiss_id_by_custom_id(cid)
if fid is not None:
to_remove.append(fid)
with self._storage_lock:
to_remove = []
for cid in ids:
fid = self._find_faiss_id_by_custom_id(cid)
if fid is not None:
to_remove.append(fid)
if to_remove:
self._remove_faiss_ids(to_remove)
logger.info(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
)
if to_remove:
self._remove_faiss_ids(to_remove)
logger.debug(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
)
async def delete_entity(self, entity_name: str) -> None:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
@@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
Delete relations for a given entity by scanning metadata.
"""
logger.debug(f"Searching relations for entity {entity_name}")
relations = []
for fid, meta in self._id_to_meta.items():
if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
relations.append(fid)
with self._storage_lock:
relations = []
for fid, meta in self._id_to_meta.items():
if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
relations.append(fid)
logger.debug(f"Found {len(relations)} relations for {entity_name}")
if relations:
self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
logger.debug(f"Found {len(relations)} relations for {entity_name}")
if relations:
self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
async def index_done_callback(self) -> None:
self._save_faiss_index()
with self._storage_lock:
self._save_faiss_index()
# --------------------------------------------------------------------------------
# Internal helper methods
@@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
"""
Return the Faiss internal ID for a given custom ID, or None if not found.
"""
for fid, meta in self._id_to_meta.items():
if meta.get("__id__") == custom_id:
return fid
return None
with self._storage_lock:
for fid, meta in self._id_to_meta.items():
if meta.get("__id__") == custom_id:
return fid
return None
def _remove_faiss_ids(self, fid_list):
"""
@@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage):
Because IndexFlatIP doesn't support 'removals',
we rebuild the index excluding those vectors.
"""
keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
with self._storage_lock:
keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
# Rebuild the index
vectors_to_keep = []
new_id_to_meta = {}
for new_fid, old_fid in enumerate(keep_fids):
vec_meta = self._id_to_meta[old_fid]
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
new_id_to_meta[new_fid] = vec_meta
# Rebuild the index
vectors_to_keep = []
new_id_to_meta = {}
for new_fid, old_fid in enumerate(keep_fids):
vec_meta = self._id_to_meta[old_fid]
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
new_id_to_meta[new_fid] = vec_meta
# Re-init index
self._index = faiss.IndexFlatIP(self._dim)
if vectors_to_keep:
arr = np.array(vectors_to_keep, dtype=np.float32)
self._index.add(arr)
# Re-init index
new_index = faiss.IndexFlatIP(self._dim)
if vectors_to_keep:
arr = np.array(vectors_to_keep, dtype=np.float32)
new_index.add(arr)
if is_multiprocess:
self._index.value = new_index
else:
self._index = new_index
self._id_to_meta = new_id_to_meta
self._id_to_meta.update(new_id_to_meta)
def _save_faiss_index(self):
"""
Save the current Faiss index + metadata to disk so it can persist across runs.
"""
faiss.write_index(self._index, self._faiss_index_file)
with self._storage_lock:
faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
# Save metadata dict to JSON. Convert all keys to strings for JSON storage.
# _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
# We'll keep the int -> dict, but JSON requires string keys.
serializable_dict = {}
for fid, meta in self._id_to_meta.items():
serializable_dict[str(fid)] = meta
# Save metadata dict to JSON. Convert all keys to strings for JSON storage.
# _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
# We'll keep the int -> dict, but JSON requires string keys.
serializable_dict = {}
for fid, meta in self._id_to_meta.items():
serializable_dict[str(fid)] = meta
with open(self._meta_file, "w", encoding="utf-8") as f:
json.dump(serializable_dict, f)
with open(self._meta_file, "w", encoding="utf-8") as f:
json.dump(serializable_dict, f)
def _load_faiss_index(self):
"""
@@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
try:
# Load the Faiss index
self._index = faiss.read_index(self._faiss_index_file)
loaded_index = faiss.read_index(self._faiss_index_file)
if is_multiprocess:
self._index.value = loaded_index
else:
self._index = loaded_index
# Load metadata
with open(self._meta_file, "r", encoding="utf-8") as f:
stored_dict = json.load(f)
# Convert string keys back to int
self._id_to_meta = {}
self._id_to_meta.update({})
for fid_str, meta in stored_dict.items():
fid = int(fid_str)
self._id_to_meta[fid] = meta
logger.info(
f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
)
except Exception as e:
logger.error(f"Failed to load Faiss index or metadata: {e}")
logger.warning("Starting with an empty Faiss index.")
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
new_index = faiss.IndexFlatIP(self._dim)
if is_multiprocess:
self._index.value = new_index
else:
self._index = new_index
self._id_to_meta.update({})
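Note that the Faiss backend above unwraps its shared index inline (self._index.value when is_multiprocess, the plain attribute otherwise), while the NanoVectorDB and NetworkX backends further down wrap the same logic in small _get_client() / _get_graph() helpers. A minimal sketch of the equivalent helper for the Faiss index (hypothetical; not part of this commit):

from lightrag.kg.shared_storage import is_multiprocess

def _get_index(storage):
    """Return the underlying Faiss index for either run mode.

    In multiprocess mode the index sits behind the manager.Value('O', ...) proxy
    created by get_namespace_object(), so it is unwrapped via .value; in
    single-process mode the index object is stored directly on the instance.
    """
    return storage._index.value if is_multiprocess else storage._index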

View File

@@ -1,7 +1,6 @@
from dataclasses import dataclass
import os
from typing import Any, Union, final
import threading
from lightrag.base import (
DocProcessingStatus,
@@ -13,26 +12,7 @@ from lightrag.utils import (
logger,
write_json,
)
from lightrag.api.utils_api import manager as main_process_manager
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_doc_status_data = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_doc_status_data
with _init_lock:
if _manager is None:
try:
_manager = main_process_manager
_shared_doc_status_data = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
from .shared_storage import get_namespace_data, get_storage_lock
@final
@@ -43,45 +23,32 @@ class JsonDocStatusStorage(DocStatusStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
# Ensure manager is initialized
_get_manager()
# Get or create namespace data
if self.namespace not in _shared_doc_status_data:
with _init_lock:
if self.namespace not in _shared_doc_status_data:
try:
initial_data = load_json(self._file_name) or {}
_shared_doc_status_data[self.namespace] = initial_data
except Exception as e:
logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
raise RuntimeError(f"Shared data initialization failed: {e}")
try:
self._data = _shared_doc_status_data[self.namespace]
logger.info(f"Loaded document status storage with {len(self._data)} records")
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
self._storage_lock = get_storage_lock()
self._data = get_namespace_data(self.namespace)
with self._storage_lock:
self._data.update(load_json(self._file_name) or {})
logger.info(f"Loaded document status storage with {len(self._data)} records")
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
return set(keys) - set(self._data.keys())
with self._storage_lock:
return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
for id in ids:
data = self._data.get(id, None)
if data:
result.append(data)
with self._storage_lock:
for id in ids:
data = self._data.get(id, None)
if data:
result.append(data)
return result
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus}
for doc in self._data.values():
counts[doc["status"]] += 1
with self._storage_lock:
for doc in self._data.values():
counts[doc["status"]] += 1
return counts
async def get_docs_by_status(
@@ -89,39 +56,46 @@ class JsonDocStatusStorage(DocStatusStorage):
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
result = {}
for k, v in self._data.items():
if v["status"] == status.value:
try:
# Make a copy of the data to avoid modifying the original
data = v.copy()
# If content is missing, use content_summary as content
if "content" not in data and "content_summary" in data:
data["content"] = data["content_summary"]
result[k] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}")
continue
with self._storage_lock:
for k, v in self._data.items():
if v["status"] == status.value:
try:
# Make a copy of the data to avoid modifying the original
data = v.copy()
# If content is missing, use content_summary as content
if "content" not in data and "content_summary" in data:
data["content"] = data["content_summary"]
result[k] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}")
continue
return result
async def index_done_callback(self) -> None:
write_json(self._data, self._file_name)
# File writes must be locked to prevent corruption when multiple processes write to the file at the same time
with self._storage_lock:
write_json(self._data, self._file_name)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
self._data.update(data)
with self._storage_lock:
self._data.update(data)
await self.index_done_callback()
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.get(id)
with self._storage_lock:
return self._data.get(id)
async def delete(self, doc_ids: list[str]):
for doc_id in doc_ids:
self._data.pop(doc_id, None)
with self._storage_lock:
for doc_id in doc_ids:
self._data.pop(doc_id, None)
await self.index_done_callback()
async def drop(self) -> None:
"""Drop the storage"""
self._data.clear()
with self._storage_lock:
self._data.clear()

View File

@@ -1,8 +1,6 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, final
import threading
from lightrag.base import (
BaseKVStorage,
@@ -12,26 +10,7 @@ from lightrag.utils import (
logger,
write_json,
)
from lightrag.api.utils_api import manager as main_process_manager
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_kv_data = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_kv_data
with _init_lock:
if _manager is None:
try:
_manager = main_process_manager
_shared_kv_data = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
from .shared_storage import get_namespace_data, get_storage_lock
@final
@@ -39,57 +18,49 @@ def _get_manager():
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._lock = asyncio.Lock()
# Ensure manager is initialized
_get_manager()
# Get or create namespace data
if self.namespace not in _shared_kv_data:
with _init_lock:
if self.namespace not in _shared_kv_data:
try:
initial_data = load_json(self._file_name) or {}
_shared_kv_data[self.namespace] = initial_data
except Exception as e:
logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
raise RuntimeError(f"Shared data initialization failed: {e}")
try:
self._data = _shared_kv_data[self.namespace]
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
self._storage_lock = get_storage_lock()
self._data = get_namespace_data(self.namespace)
with self._storage_lock:
if not self._data:
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data: dict[str, Any] = load_json(self._file_name) or {}
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def index_done_callback(self) -> None:
write_json(self._data, self._file_name)
# File writes must be locked to prevent corruption when multiple processes write to the file at the same time
with self._storage_lock:
write_json(self._data, self._file_name)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return self._data.get(id)
with self._storage_lock:
return self._data.get(id)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return [
(
{k: v for k, v in self._data[id].items()}
if self._data.get(id, None)
else None
)
for id in ids
]
with self._storage_lock:
return [
(
{k: v for k, v in self._data[id].items()}
if self._data.get(id, None)
else None
)
for id in ids
]
async def filter_keys(self, keys: set[str]) -> set[str]:
return set(keys) - set(self._data.keys())
with self._storage_lock:
return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
with self._storage_lock:
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
async def delete(self, ids: list[str]) -> None:
for doc_id in ids:
self._data.pop(doc_id, None)
with self._storage_lock:
for doc_id in ids:
self._data.pop(doc_id, None)
await self.index_done_callback()

View File

@@ -3,50 +3,29 @@ import os
from typing import Any, final
from dataclasses import dataclass
import numpy as np
import threading
import time
from lightrag.utils import (
logger,
compute_mdhash_id,
)
from lightrag.api.utils_api import manager as main_process_manager
import pipmaster as pm
from lightrag.base import (
BaseVectorStorage,
)
from lightrag.base import BaseVectorStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_vector_clients = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_vector_clients
with _init_lock:
if _manager is None:
try:
_manager = main_process_manager
_shared_vector_clients = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
@final
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
# Initialize lock only for file operations
self._save_lock = asyncio.Lock()
self._storage_lock = get_storage_lock()
# Use global config value if specified, otherwise use default
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -61,28 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
self._max_batch_size = self.global_config["embedding_batch_num"]
# Ensure manager is initialized
_get_manager()
self._client = get_namespace_object(self.namespace)
# Get or create namespace client
if self.namespace not in _shared_vector_clients:
with _init_lock:
if self.namespace not in _shared_vector_clients:
try:
client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name
)
_shared_vector_clients[self.namespace] = client
except Exception as e:
logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}")
raise RuntimeError(f"Vector DB client initialization failed: {e}")
with self._storage_lock:
if is_multiprocess:
if self._client.value is None:
self._client.value = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
else:
if self._client is None:
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
try:
self._client = _shared_vector_clients[self.namespace]
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
logger.info(f"Initialized vector DB client for namespace {self.namespace}")
def _get_client(self):
"""Get the appropriate client instance based on multiprocess mode"""
if is_multiprocess:
return self._client.value
return self._client
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
@@ -104,6 +82,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
for i in range(0, len(contents), self._max_batch_size)
]
# Execute embedding outside of lock to avoid long lock times
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = await asyncio.gather(*embedding_tasks)
@@ -111,7 +90,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
with self._storage_lock:
client = self._get_client()
results = client.upsert(datas=list_data)
return results
else:
# sometimes the embedding is not returned correctly. just log it.
@@ -120,27 +101,32 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid long lock times
embedding = await self.embedding_func([query])
embedding = embedding[0]
results = self._client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
results = [
{
**dp,
"id": dp["__id__"],
"distance": dp["__metrics__"],
"created_at": dp.get("__created_at__"),
}
for dp in results
]
with self._storage_lock:
client = self._get_client()
results = client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
results = [
{
**dp,
"id": dp["__id__"],
"distance": dp["__metrics__"],
"created_at": dp.get("__created_at__"),
}
for dp in results
]
return results
@property
def client_storage(self):
return getattr(self._client, "_NanoVectorDB__storage")
client = self._get_client()
return getattr(client, "_NanoVectorDB__storage")
async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs
@@ -149,8 +135,10 @@ class NanoVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted
"""
try:
self._client.delete(ids)
logger.info(
with self._storage_lock:
client = self._get_client()
client.delete(ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
@@ -162,35 +150,42 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Check if the entity exists
if self._client.get([entity_id]):
await self.delete([entity_id])
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
with self._storage_lock:
client = self._get_client()
# Check if the entity exists
if client.get([entity_id]):
client.delete([entity_id])
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
try:
relations = [
dp
for dp in self.client_storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
ids_to_delete = [relation["__id__"] for relation in relations]
with self._storage_lock:
client = self._get_client()
storage = getattr(client, "_NanoVectorDB__storage")
relations = [
dp
for dp in storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete:
await self.delete(ids_to_delete)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
)
else:
logger.debug(f"No relations found for entity {entity_name}")
if ids_to_delete:
client.delete(ids_to_delete)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
)
else:
logger.debug(f"No relations found for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self) -> None:
async with self._save_lock:
self._client.save()
with self._storage_lock:
client = self._get_client()
client.save()

View File

@@ -1,18 +1,13 @@
import os
from dataclasses import dataclass
from typing import Any, final
import threading
import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import (
logger,
)
from lightrag.api.utils_api import manager as main_process_manager
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
from lightrag.base import (
BaseGraphStorage,
)
import pipmaster as pm
if not pm.is_installed("networkx"):
@@ -24,25 +19,6 @@ if not pm.is_installed("graspologic"):
import networkx as nx
from graspologic import embed
# Global variables for shared memory management
_init_lock = threading.Lock()
_manager = None
_shared_graphs = None
def _get_manager():
"""Get or create the global manager instance"""
global _manager, _shared_graphs
with _init_lock:
if _manager is None:
try:
_manager = main_process_manager
_shared_graphs = _manager.dict()
except Exception as e:
logger.error(f"Failed to initialize shared memory manager: {e}")
raise RuntimeError(f"Shared memory initialization failed: {e}")
return _manager
@final
@dataclass
@@ -97,76 +73,98 @@ class NetworkXStorage(BaseGraphStorage):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
# Ensure manager is initialized
_get_manager()
# Get or create namespace graph
if self.namespace not in _shared_graphs:
with _init_lock:
if self.namespace not in _shared_graphs:
try:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
_shared_graphs[self.namespace] = preloaded_graph or nx.Graph()
except Exception as e:
logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}")
raise RuntimeError(f"Graph initialization failed: {e}")
try:
self._graph = _shared_graphs[self.namespace]
self._node_embed_algorithms = {
self._storage_lock = get_storage_lock()
self._graph = get_namespace_object(self.namespace)
with self._storage_lock:
if is_multiprocess:
if self._graph.value is None:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
self._graph.value = preloaded_graph or nx.Graph()
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
if self._graph is None:
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
self._graph = preloaded_graph or nx.Graph()
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
}
def _get_graph(self):
"""Get the appropriate graph instance based on multiprocess mode"""
if is_multiprocess:
return self._graph.value
return self._graph
async def index_done_callback(self) -> None:
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
with self._storage_lock:
graph = self._get_graph()
NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
return self._graph.has_node(node_id)
with self._storage_lock:
graph = self._get_graph()
return graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
with self._storage_lock:
graph = self._get_graph()
return graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id)
with self._storage_lock:
graph = self._get_graph()
return graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
return self._graph.degree(node_id)
with self._storage_lock:
graph = self._get_graph()
return graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
with self._storage_lock:
graph = self._get_graph()
return graph.degree(src_id) + graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id))
with self._storage_lock:
graph = self._get_graph()
return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
with self._storage_lock:
graph = self._get_graph()
if graph.has_node(source_node_id):
return list(graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data)
with self._storage_lock:
graph = self._get_graph()
graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
with self._storage_lock:
graph = self._get_graph()
graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str) -> None:
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
with self._storage_lock:
graph = self._get_graph()
if graph.has_node(node_id):
graph.remove_node(node_id)
logger.debug(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(
self, algorithm: str
@@ -175,14 +173,15 @@ class NetworkXStorage(BaseGraphStorage):
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
# TODO: NOT USED
async def _node2vec_embed(self):
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
with self._storage_lock:
graph = self._get_graph()
embeddings, nodes = embed.node2vec_embed(
graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
def remove_nodes(self, nodes: list[str]):
@@ -191,9 +190,11 @@ class NetworkXStorage(BaseGraphStorage):
Args:
nodes: List of node IDs to be deleted
"""
for node in nodes:
if self._graph.has_node(node):
self._graph.remove_node(node)
with self._storage_lock:
graph = self._get_graph()
for node in nodes:
if graph.has_node(node):
graph.remove_node(node)
def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
@@ -201,9 +202,11 @@ class NetworkXStorage(BaseGraphStorage):
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target)
with self._storage_lock:
graph = self._get_graph()
for source, target in edges:
if graph.has_edge(source, target):
graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
"""
@@ -211,9 +214,11 @@ class NetworkXStorage(BaseGraphStorage):
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
labels = set()
for node in self._graph.nodes():
labels.add(str(node)) # Add node id as a label
with self._storage_lock:
graph = self._get_graph()
labels = set()
for node in graph.nodes():
labels.add(str(node)) # Add node id as a label
# Return sorted list
return sorted(list(labels))
@@ -235,87 +240,86 @@ class NetworkXStorage(BaseGraphStorage):
seen_nodes = set()
seen_edges = set()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
subgraph = (
self._graph.copy()
) # Create a copy to avoid modifying the original graph
else:
# Find nodes with matching node id (partial match)
nodes_to_explore = []
for n, attr in self._graph.nodes(data=True):
if node_label in str(n): # Use partial matching
nodes_to_explore.append(n)
with self._storage_lock:
graph = self._get_graph()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
subgraph = graph.copy() # Create a copy to avoid modifying the original graph
else:
# Find nodes with matching node id (partial match)
nodes_to_explore = []
for n, attr in graph.nodes(data=True):
if node_label in str(n): # Use partial matching
nodes_to_explore.append(n)
if not nodes_to_explore:
logger.warning(f"No nodes found with label {node_label}")
return result
if not nodes_to_explore:
logger.warning(f"No nodes found with label {node_label}")
return result
# Get subgraph using ego_graph
subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
# Get subgraph using ego_graph
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
# Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500
if len(subgraph.nodes()) > max_graph_nodes:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
:max_graph_nodes
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph with only top nodes
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
)
# Add nodes to result
for node in subgraph.nodes():
if str(node) in seen_nodes:
continue
node_data = dict(subgraph.nodes[node])
# Get entity_type as labels
labels = []
if "entity_type" in node_data:
if isinstance(node_data["entity_type"], list):
labels.extend(node_data["entity_type"])
else:
labels.append(node_data["entity_type"])
# Create node with properties
node_properties = {k: v for k, v in node_data.items()}
result.nodes.append(
KnowledgeGraphNode(
id=str(node), labels=[str(node)], properties=node_properties
# Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500
if len(subgraph.nodes()) > max_graph_nodes:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
:max_graph_nodes
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph with only top nodes
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
)
)
seen_nodes.add(str(node))
# Add edges to result
for edge in subgraph.edges():
source, target = edge
edge_id = f"{source}-{target}"
if edge_id in seen_edges:
continue
# Add nodes to result
for node in subgraph.nodes():
if str(node) in seen_nodes:
continue
edge_data = dict(subgraph.edges[edge])
node_data = dict(subgraph.nodes[node])
# Get entity_type as labels
labels = []
if "entity_type" in node_data:
if isinstance(node_data["entity_type"], list):
labels.extend(node_data["entity_type"])
else:
labels.append(node_data["entity_type"])
# Create edge with complete information
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(source),
target=str(target),
properties=edge_data,
# Create node with properties
node_properties = {k: v for k, v in node_data.items()}
result.nodes.append(
KnowledgeGraphNode(
id=str(node), labels=[str(node)], properties=node_properties
)
)
)
seen_edges.add(edge_id)
seen_nodes.add(str(node))
# logger.info(result.edges)
# Add edges to result
for edge in subgraph.edges():
source, target = edge
edge_id = f"{source}-{target}"
if edge_id in seen_edges:
continue
edge_data = dict(subgraph.edges[edge])
# Create edge with complete information
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(source),
target=str(target),
properties=edge_data,
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"

View File

@@ -0,0 +1,94 @@
from multiprocessing.synchronize import Lock as ProcessLock
from threading import Lock as ThreadLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union
# Type definitions
LockType = Union[ProcessLock, ThreadLock]
# Global variables
_shared_data: Optional[Dict[str, Any]] = None
_namespace_objects: Optional[Dict[str, Any]] = None
_global_lock: Optional[LockType] = None
is_multiprocess = False
manager = None
def initialize_manager():
"""Initialize manager, only for multiple processes where workers > 1"""
global manager
if manager is None:
manager = Manager()
def _get_global_lock() -> LockType:
global _global_lock, is_multiprocess
if _global_lock is None:
if is_multiprocess:
_global_lock = manager.Lock()
else:
_global_lock = ThreadLock()
return _global_lock
def get_storage_lock() -> LockType:
"""return storage lock for data consistency"""
return _get_global_lock()
def get_scan_lock() -> LockType:
"""return scan_progress lock for data consistency"""
return get_storage_lock()
def get_shared_data() -> Dict[str, Any]:
"""
return shared data for all storage types
create mult-process save share data only if need for better performance
"""
global _shared_data, is_multiprocess
if _shared_data is None:
lock = _get_global_lock()
with lock:
if _shared_data is None:
if is_multiprocess:
_shared_data = manager.dict()
else:
_shared_data = {}
return _shared_data
def get_namespace_object(namespace: str) -> Any:
"""Get an object for specific namespace"""
global _namespace_objects, is_multiprocess
if _namespace_objects is None:
lock = _get_global_lock()
with lock:
if _namespace_objects is None:
_namespace_objects = {}
if namespace not in _namespace_objects:
lock = _get_global_lock()
with lock:
if namespace not in _namespace_objects:
if is_multiprocess:
_namespace_objects[namespace] = manager.Value('O', None)
else:
_namespace_objects[namespace] = None
return _namespace_objects[namespace]
def get_namespace_data(namespace: str) -> Dict[str, Any]:
"""get storage space for specific storage type(namespace)"""
shared_data = get_shared_data()
lock = _get_global_lock()
if namespace not in shared_data:
with lock:
if namespace not in shared_data:
shared_data[namespace] = {}
return shared_data[namespace]
def get_scan_progress() -> Dict[str, Any]:
"""get storage space for document scanning progress data"""
return get_namespace_data('scan_progress')
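Taken together with the main() change at the top of this commit and run_scanning_process() above, the intended wiring looks roughly like this (a sketch, not code from the diff; the workers value stands in for the parsed CLI argument):

import lightrag.kg.shared_storage as shared_storage
from lightrag.kg.shared_storage import (
    initialize_manager,
    get_scan_progress,
    get_scan_lock,
)

def configure_shared_storage(workers: int) -> None:
    # Mirrors main(): the Manager is only created when more than one worker
    # will be spawned; otherwise thread locks and plain dicts are enough.
    if workers > 1:
        initialize_manager()
        shared_storage.is_multiprocess = True

def try_start_scan() -> bool:
    # Mirrors the check at the top of run_scanning_process(): take the scan
    # lock, skip if another scan is already active, otherwise flag the scan
    # as started.
    scan_progress = get_scan_progress()
    with get_scan_lock():
        if scan_progress.get("is_scanning"):  # the dict starts empty, so use .get()
            return False
        scan_progress.update({"is_scanning": True, "progress": 0})
        return True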

View File

@@ -266,13 +266,7 @@ class LightRAG:
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self):
# Initialize manager if needed
from lightrag.api.utils_api import manager, initialize_manager
if manager is None:
initialize_manager()
logger.info("Initialized manager for single process mode")
def __post_init__(self):
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
set_logger(self.log_file_path, self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")