Refactor storage implementations to support both single and multi-process modes
• Add shared storage management module
• Support process/thread lock based on mode
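For readers skimming the diff, the core of the new mode switch is condensed below. This is a sketch of the pattern only; the full module, lightrag/kg/shared_storage.py, is added at the end of this commit and also provides namespace data/object helpers.

    from multiprocessing import Manager
    from threading import Lock as ThreadLock

    is_multiprocess = False   # set to True in main() when args.workers > 1
    manager = None            # created by initialize_manager() before workers start
    _global_lock = None

    def initialize_manager():
        """Create the cross-process Manager once, in the parent process."""
        global manager
        if manager is None:
            manager = Manager()

    def get_storage_lock():
        """Hand out a process lock in multi-worker mode, a thread lock otherwise."""
        global _global_lock
        if _global_lock is None:
            _global_lock = manager.Lock() if is_multiprocess else ThreadLock()
        return _global_lock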
@@ -406,9 +406,6 @@ def create_app(args):

def get_application():
    """Factory function for creating the FastAPI application"""
    from .utils_api import initialize_manager
    initialize_manager()

    # Get args from environment variable
    args_json = os.environ.get('LIGHTRAG_ARGS')
    if not args_json:
@@ -428,6 +425,12 @@ def main():
    # Save args to environment variable for child processes
    os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))

    if args.workers > 1:
        from lightrag.kg.shared_storage import initialize_manager
        initialize_manager()
        import lightrag.kg.shared_storage as shared_storage
        shared_storage.is_multiprocess = True

    # Configure uvicorn logging
    logging.config.dictConfig({
        "version": 1,
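For orientation, a hedged sketch of how the pieces above fit together at startup. The exact uvicorn invocation is not shown in this hunk; the call below is an assumption based on the `get_application` factory and the `LIGHTRAG_ARGS` environment variable.

    import json
    import os
    import uvicorn

    def main_sketch(args):
        # Worker processes cannot inherit argparse results directly, so the args
        # are passed through the environment and re-read in get_application().
        os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))

        if args.workers > 1:
            # Create the multiprocessing.Manager in the parent before workers start,
            # and flip the module-level flag so storages pick process locks.
            from lightrag.kg.shared_storage import initialize_manager
            import lightrag.kg.shared_storage as shared_storage
            initialize_manager()
            shared_storage.is_multiprocess = True

        # Assumed invocation: a string import path plus factory=True lets uvicorn
        # build the app inside each worker process.
        uvicorn.run(
            "lightrag.api.lightrag_server:get_application",
            factory=True,
            workers=args.workers,
        )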
@@ -18,12 +18,10 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from lightrag import LightRAG
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from ..utils_api import (
|
||||
get_api_key_dependency,
|
||||
scan_progress,
|
||||
update_scan_progress_if_not_scanning,
|
||||
update_scan_progress,
|
||||
reset_scan_progress,
|
||||
from ..utils_api import get_api_key_dependency
|
||||
from lightrag.kg.shared_storage import (
|
||||
get_scan_progress,
|
||||
get_scan_lock,
|
||||
)
|
||||
|
||||
|
||||
@@ -378,23 +376,51 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
||||
|
||||
async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
||||
"""Background task to scan and index documents"""
|
||||
if not update_scan_progress_if_not_scanning():
|
||||
ASCIIColors.info(
|
||||
"Skip document scanning(another scanning is active)"
|
||||
)
|
||||
return
|
||||
scan_progress = get_scan_progress()
|
||||
scan_lock = get_scan_lock()
|
||||
|
||||
with scan_lock:
|
||||
if scan_progress["is_scanning"]:
|
||||
ASCIIColors.info(
|
||||
"Skip document scanning(another scanning is active)"
|
||||
)
|
||||
return
|
||||
scan_progress.update({
|
||||
"is_scanning": True,
|
||||
"current_file": "",
|
||||
"indexed_count": 0,
|
||||
"total_files": 0,
|
||||
"progress": 0,
|
||||
})
|
||||
|
||||
try:
|
||||
new_files = doc_manager.scan_directory_for_new_files()
|
||||
total_files = len(new_files)
|
||||
update_scan_progress("", total_files, 0) # Initialize progress
|
||||
scan_progress.update({
|
||||
"current_file": "",
|
||||
"total_files": total_files,
|
||||
"indexed_count": 0,
|
||||
"progress": 0,
|
||||
})
|
||||
|
||||
logging.info(f"Found {total_files} new files to index.")
|
||||
for idx, file_path in enumerate(new_files):
|
||||
try:
|
||||
update_scan_progress(os.path.basename(file_path), total_files, idx)
|
||||
progress = (idx / total_files * 100) if total_files > 0 else 0
|
||||
scan_progress.update({
|
||||
"current_file": os.path.basename(file_path),
|
||||
"indexed_count": idx,
|
||||
"progress": progress,
|
||||
})
|
||||
|
||||
await pipeline_index_file(rag, file_path)
|
||||
update_scan_progress(os.path.basename(file_path), total_files, idx + 1)
|
||||
|
||||
progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
|
||||
scan_progress.update({
|
||||
"current_file": os.path.basename(file_path),
|
||||
"indexed_count": idx + 1,
|
||||
"progress": progress,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error indexing file {file_path}: {str(e)}")
|
||||
@@ -402,7 +428,13 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
||||
except Exception as e:
|
||||
logging.error(f"Error during scanning process: {str(e)}")
|
||||
finally:
|
||||
reset_scan_progress()
|
||||
scan_progress.update({
|
||||
"is_scanning": False,
|
||||
"current_file": "",
|
||||
"indexed_count": 0,
|
||||
"total_files": 0,
|
||||
"progress": 0,
|
||||
})
|
||||
|
||||
|
||||
def create_document_routes(
|
||||
@@ -427,7 +459,7 @@ def create_document_routes(
|
||||
return {"status": "scanning_started"}
|
||||
|
||||
@router.get("/scan-progress")
|
||||
async def get_scan_progress():
|
||||
async def get_scanning_progress():
|
||||
"""
|
||||
Get the current progress of the document scanning process.
|
||||
|
||||
@@ -439,7 +471,7 @@ def create_document_routes(
|
||||
- total_files: Total number of files to process
|
||||
- progress: Percentage of completion
|
||||
"""
|
||||
return dict(scan_progress)
|
||||
return dict(get_scan_progress())
|
||||
|
||||
@router.post("/upload", dependencies=[Depends(optional_api_key)])
|
||||
async def upload_to_input_dir(
|
||||
|
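The scanning guard above replaces the old `update_scan_progress_if_not_scanning()` helper with an explicit check-and-set under the shared lock. A condensed sketch of that pattern, with names taken from the diff and the body trimmed:

    from lightrag.kg.shared_storage import get_scan_progress, get_scan_lock

    def try_start_scan() -> bool:
        """Atomically claim the scanner; return False if another scan is active."""
        scan_progress = get_scan_progress()
        scan_lock = get_scan_lock()
        with scan_lock:
            if scan_progress["is_scanning"]:
                return False
            scan_progress.update({
                "is_scanning": True,
                "current_file": "",
                "indexed_count": 0,
                "total_files": 0,
                "progress": 0,
            })
            return True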
@@ -6,7 +6,6 @@ import os
|
||||
import argparse
|
||||
from typing import Optional
|
||||
import sys
|
||||
from multiprocessing import Manager
|
||||
from ascii_colors import ASCIIColors
|
||||
from lightrag.api import __api_version__
|
||||
from fastapi import HTTPException, Security
|
||||
@@ -17,66 +16,6 @@ from starlette.status import HTTP_403_FORBIDDEN
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Global variables for manager and shared state
|
||||
manager = None
|
||||
scan_progress = None
|
||||
scan_lock = None
|
||||
|
||||
def initialize_manager():
|
||||
"""Initialize manager and shared state for cross-process communication"""
|
||||
global manager, scan_progress, scan_lock
|
||||
if manager is None:
|
||||
manager = Manager()
|
||||
scan_progress = manager.dict({
|
||||
"is_scanning": False,
|
||||
"current_file": "",
|
||||
"indexed_count": 0,
|
||||
"total_files": 0,
|
||||
"progress": 0,
|
||||
})
|
||||
scan_lock = manager.Lock()
|
||||
|
||||
def update_scan_progress_if_not_scanning():
|
||||
"""
|
||||
Atomically check if scanning is not in progress and update scan_progress if it's not.
|
||||
Returns True if the update was successful, False if scanning was already in progress.
|
||||
"""
|
||||
with scan_lock:
|
||||
if not scan_progress["is_scanning"]:
|
||||
scan_progress.update({
|
||||
"is_scanning": True,
|
||||
"current_file": "",
|
||||
"indexed_count": 0,
|
||||
"total_files": 0,
|
||||
"progress": 0,
|
||||
})
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_scan_progress(current_file: str, total_files: int, indexed_count: int):
|
||||
"""
|
||||
Atomically update scan progress information.
|
||||
"""
|
||||
progress = (indexed_count / total_files * 100) if total_files > 0 else 0
|
||||
scan_progress.update({
|
||||
"current_file": current_file,
|
||||
"indexed_count": indexed_count,
|
||||
"total_files": total_files,
|
||||
"progress": progress,
|
||||
})
|
||||
|
||||
def reset_scan_progress():
|
||||
"""
|
||||
Atomically reset scan progress to initial state.
|
||||
"""
|
||||
scan_progress.update({
|
||||
"is_scanning": False,
|
||||
"current_file": "",
|
||||
"indexed_count": 0,
|
||||
"total_files": 0,
|
||||
"progress": 0,
|
||||
})
|
||||
|
||||
|
||||
class OllamaServerInfos:
|
||||
# Constants for emulated Ollama model information
|
||||
|
@@ -2,48 +2,21 @@ import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Any, final
|
||||
import threading
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
import pipmaster as pm
|
||||
from lightrag.api.utils_api import manager as main_process_manager
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
compute_mdhash_id,
|
||||
)
|
||||
from lightrag.base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
from lightrag.utils import logger, compute_mdhash_id
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
|
||||
|
||||
if not pm.is_installed("faiss"):
|
||||
pm.install("faiss")
|
||||
|
||||
import faiss # type: ignore
|
||||
|
||||
# Global variables for shared memory management
|
||||
_init_lock = threading.Lock()
|
||||
_manager = None
|
||||
_shared_indices = None
|
||||
_shared_meta = None
|
||||
|
||||
|
||||
def _get_manager():
|
||||
"""Get or create the global manager instance"""
|
||||
global _manager, _shared_indices, _shared_meta
|
||||
with _init_lock:
|
||||
if _manager is None:
|
||||
try:
|
||||
_manager = main_process_manager
|
||||
_shared_indices = _manager.dict()
|
||||
_shared_meta = _manager.dict()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared memory manager: {e}")
|
||||
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
||||
return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Embedding dimension (e.g. 768) must match your embedding function
|
||||
self._dim = self.embedding_func.embedding_dim
|
||||
self._storage_lock = get_storage_lock()
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
self._index = get_namespace_object('faiss_indices')
|
||||
self._id_to_meta = get_namespace_data('faiss_meta')
|
||||
|
||||
# Get or create namespace index and metadata
|
||||
if self.namespace not in _shared_indices:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_indices:
|
||||
try:
|
||||
# Create an empty Faiss index for inner product
|
||||
index = faiss.IndexFlatIP(self._dim)
|
||||
meta = {}
|
||||
|
||||
# Load existing index if available
|
||||
if os.path.exists(self._faiss_index_file):
|
||||
try:
|
||||
index = faiss.read_index(self._faiss_index_file)
|
||||
with open(self._meta_file, "r", encoding="utf-8") as f:
|
||||
stored_dict = json.load(f)
|
||||
# Convert string keys back to int
|
||||
meta = {int(k): v for k, v in stored_dict.items()}
|
||||
logger.info(
|
||||
f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Faiss index or metadata: {e}")
|
||||
logger.warning("Starting with an empty Faiss index.")
|
||||
index = faiss.IndexFlatIP(self._dim)
|
||||
meta = {}
|
||||
|
||||
_shared_indices[self.namespace] = index
|
||||
_shared_meta[self.namespace] = meta
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Faiss index initialization failed: {e}")
|
||||
|
||||
try:
|
||||
self._index = _shared_indices[self.namespace]
|
||||
self._id_to_meta = _shared_meta[self.namespace]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
with self._storage_lock:
|
||||
if is_multiprocess:
|
||||
if self._index.value is None:
|
||||
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
||||
# If you have a large number of vectors, you might want IVF or other indexes.
|
||||
# For demonstration, we use a simple IndexFlatIP.
|
||||
self._index.value = faiss.IndexFlatIP(self._dim)
|
||||
else:
|
||||
if self._index is None:
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
|
||||
# Keep a local store for metadata, IDs, etc.
|
||||
# Maps <int faiss_id> → metadata (including your original ID).
|
||||
self._id_to_meta.update({})
|
||||
|
||||
# Attempt to load an existing index + metadata from disk
|
||||
self._load_faiss_index()
|
||||
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
@@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
# Normalize embeddings for cosine similarity (in-place)
|
||||
faiss.normalize_L2(embeddings)
|
||||
|
||||
# Upsert logic:
|
||||
# 1. Identify which vectors to remove if they exist
|
||||
# 2. Remove them
|
||||
# 3. Add the new vectors
|
||||
existing_ids_to_remove = []
|
||||
for meta, emb in zip(list_data, embeddings):
|
||||
faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
|
||||
if faiss_internal_id is not None:
|
||||
existing_ids_to_remove.append(faiss_internal_id)
|
||||
with self._storage_lock:
|
||||
# Upsert logic:
|
||||
# 1. Identify which vectors to remove if they exist
|
||||
# 2. Remove them
|
||||
# 3. Add the new vectors
|
||||
existing_ids_to_remove = []
|
||||
for meta, emb in zip(list_data, embeddings):
|
||||
faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
|
||||
if faiss_internal_id is not None:
|
||||
existing_ids_to_remove.append(faiss_internal_id)
|
||||
|
||||
if existing_ids_to_remove:
|
||||
self._remove_faiss_ids(existing_ids_to_remove)
|
||||
if existing_ids_to_remove:
|
||||
self._remove_faiss_ids(existing_ids_to_remove)
|
||||
|
||||
# Step 2: Add new vectors
|
||||
start_idx = self._index.ntotal
|
||||
self._index.add(embeddings)
|
||||
# Step 2: Add new vectors
|
||||
start_idx = (self._index.value if is_multiprocess else self._index).ntotal
|
||||
if is_multiprocess:
|
||||
self._index.value.add(embeddings)
|
||||
else:
|
||||
self._index.add(embeddings)
|
||||
|
||||
# Step 3: Store metadata + vector for each new ID
|
||||
for i, meta in enumerate(list_data):
|
||||
fid = start_idx + i
|
||||
# Store the raw vector so we can rebuild if something is removed
|
||||
meta["__vector__"] = embeddings[i].tolist()
|
||||
self._id_to_meta[fid] = meta
|
||||
# Step 3: Store metadata + vector for each new ID
|
||||
for i, meta in enumerate(list_data):
|
||||
fid = start_idx + i
|
||||
# Store the raw vector so we can rebuild if something is removed
|
||||
meta["__vector__"] = embeddings[i].tolist()
|
||||
self._id_to_meta.update({fid: meta})
|
||||
|
||||
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
||||
return [m["__id__"] for m in list_data]
|
||||
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
||||
return [m["__id__"] for m in list_data]
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""
|
||||
@@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
# Perform the similarity search
|
||||
distances, indices = self._index.search(embedding, top_k)
|
||||
with self._storage_lock:
|
||||
distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
|
||||
|
||||
distances = distances[0]
|
||||
indices = indices[0]
|
||||
distances = distances[0]
|
||||
indices = indices[0]
|
||||
|
||||
results = []
|
||||
for dist, idx in zip(distances, indices):
|
||||
if idx == -1:
|
||||
# Faiss returns -1 if no neighbor
|
||||
continue
|
||||
results = []
|
||||
for dist, idx in zip(distances, indices):
|
||||
if idx == -1:
|
||||
# Faiss returns -1 if no neighbor
|
||||
continue
|
||||
|
||||
# Cosine similarity threshold
|
||||
if dist < self.cosine_better_than_threshold:
|
||||
continue
|
||||
# Cosine similarity threshold
|
||||
if dist < self.cosine_better_than_threshold:
|
||||
continue
|
||||
|
||||
meta = self._id_to_meta.get(idx, {})
|
||||
results.append(
|
||||
{
|
||||
**meta,
|
||||
"id": meta.get("__id__"),
|
||||
"distance": float(dist),
|
||||
"created_at": meta.get("__created_at__"),
|
||||
}
|
||||
)
|
||||
meta = self._id_to_meta.get(idx, {})
|
||||
results.append(
|
||||
{
|
||||
**meta,
|
||||
"id": meta.get("__id__"),
|
||||
"distance": float(dist),
|
||||
"created_at": meta.get("__created_at__"),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
return results
|
||||
|
||||
@property
|
||||
def client_storage(self):
|
||||
# Return whatever structure LightRAG might need for debugging
|
||||
return {"data": list(self._id_to_meta.values())}
|
||||
with self._storage_lock:
|
||||
return {"data": list(self._id_to_meta.values())}
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""
|
||||
Delete vectors for the provided custom IDs.
|
||||
"""
|
||||
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
||||
to_remove = []
|
||||
for cid in ids:
|
||||
fid = self._find_faiss_id_by_custom_id(cid)
|
||||
if fid is not None:
|
||||
to_remove.append(fid)
|
||||
with self._storage_lock:
|
||||
to_remove = []
|
||||
for cid in ids:
|
||||
fid = self._find_faiss_id_by_custom_id(cid)
|
||||
if fid is not None:
|
||||
to_remove.append(fid)
|
||||
|
||||
if to_remove:
|
||||
self._remove_faiss_ids(to_remove)
|
||||
logger.info(
|
||||
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
|
||||
)
|
||||
if to_remove:
|
||||
self._remove_faiss_ids(to_remove)
|
||||
logger.debug(
|
||||
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
|
||||
)
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
@@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
Delete relations for a given entity by scanning metadata.
|
||||
"""
|
||||
logger.debug(f"Searching relations for entity {entity_name}")
|
||||
relations = []
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
|
||||
relations.append(fid)
|
||||
with self._storage_lock:
|
||||
relations = []
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
|
||||
relations.append(fid)
|
||||
|
||||
logger.debug(f"Found {len(relations)} relations for {entity_name}")
|
||||
if relations:
|
||||
self._remove_faiss_ids(relations)
|
||||
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
|
||||
logger.debug(f"Found {len(relations)} relations for {entity_name}")
|
||||
if relations:
|
||||
self._remove_faiss_ids(relations)
|
||||
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
self._save_faiss_index()
|
||||
with self._storage_lock:
|
||||
self._save_faiss_index()
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# Internal helper methods
|
||||
@@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
"""
|
||||
Return the Faiss internal ID for a given custom ID, or None if not found.
|
||||
"""
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
if meta.get("__id__") == custom_id:
|
||||
return fid
|
||||
return None
|
||||
with self._storage_lock:
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
if meta.get("__id__") == custom_id:
|
||||
return fid
|
||||
return None
|
||||
|
||||
def _remove_faiss_ids(self, fid_list):
|
||||
"""
|
||||
@@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
Because IndexFlatIP doesn't support 'removals',
|
||||
we rebuild the index excluding those vectors.
|
||||
"""
|
||||
keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
|
||||
with self._storage_lock:
|
||||
keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
|
||||
|
||||
# Rebuild the index
|
||||
vectors_to_keep = []
|
||||
new_id_to_meta = {}
|
||||
for new_fid, old_fid in enumerate(keep_fids):
|
||||
vec_meta = self._id_to_meta[old_fid]
|
||||
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
|
||||
new_id_to_meta[new_fid] = vec_meta
|
||||
# Rebuild the index
|
||||
vectors_to_keep = []
|
||||
new_id_to_meta = {}
|
||||
for new_fid, old_fid in enumerate(keep_fids):
|
||||
vec_meta = self._id_to_meta[old_fid]
|
||||
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
|
||||
new_id_to_meta[new_fid] = vec_meta
|
||||
|
||||
# Re-init index
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
if vectors_to_keep:
|
||||
arr = np.array(vectors_to_keep, dtype=np.float32)
|
||||
self._index.add(arr)
|
||||
# Re-init index
|
||||
new_index = faiss.IndexFlatIP(self._dim)
|
||||
if vectors_to_keep:
|
||||
arr = np.array(vectors_to_keep, dtype=np.float32)
|
||||
new_index.add(arr)
|
||||
if is_multiprocess:
|
||||
self._index.value = new_index
|
||||
else:
|
||||
self._index = new_index
|
||||
|
||||
self._id_to_meta = new_id_to_meta
|
||||
self._id_to_meta.update(new_id_to_meta)
|
||||
|
||||
def _save_faiss_index(self):
|
||||
"""
|
||||
Save the current Faiss index + metadata to disk so it can persist across runs.
|
||||
"""
|
||||
faiss.write_index(self._index, self._faiss_index_file)
|
||||
with self._storage_lock:
|
||||
faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
|
||||
|
||||
# Save metadata dict to JSON. Convert all keys to strings for JSON storage.
|
||||
# _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
|
||||
# We'll keep the int -> dict, but JSON requires string keys.
|
||||
serializable_dict = {}
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
serializable_dict[str(fid)] = meta
|
||||
# Save metadata dict to JSON. Convert all keys to strings for JSON storage.
|
||||
# _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
|
||||
# We'll keep the int -> dict, but JSON requires string keys.
|
||||
serializable_dict = {}
|
||||
for fid, meta in self._id_to_meta.items():
|
||||
serializable_dict[str(fid)] = meta
|
||||
|
||||
with open(self._meta_file, "w", encoding="utf-8") as f:
|
||||
json.dump(serializable_dict, f)
|
||||
with open(self._meta_file, "w", encoding="utf-8") as f:
|
||||
json.dump(serializable_dict, f)
|
||||
|
||||
def _load_faiss_index(self):
|
||||
"""
|
||||
@@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
try:
|
||||
# Load the Faiss index
|
||||
self._index = faiss.read_index(self._faiss_index_file)
|
||||
loaded_index = faiss.read_index(self._faiss_index_file)
|
||||
if is_multiprocess:
|
||||
self._index.value = loaded_index
|
||||
else:
|
||||
self._index = loaded_index
|
||||
|
||||
# Load metadata
|
||||
with open(self._meta_file, "r", encoding="utf-8") as f:
|
||||
stored_dict = json.load(f)
|
||||
|
||||
# Convert string keys back to int
|
||||
self._id_to_meta = {}
|
||||
self._id_to_meta.update({})
|
||||
for fid_str, meta in stored_dict.items():
|
||||
fid = int(fid_str)
|
||||
self._id_to_meta[fid] = meta
|
||||
|
||||
logger.info(
|
||||
f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
|
||||
f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Faiss index or metadata: {e}")
|
||||
logger.warning("Starting with an empty Faiss index.")
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
new_index = faiss.IndexFlatIP(self._dim)
|
||||
if is_multiprocess:
|
||||
self._index.value = new_index
|
||||
else:
|
||||
self._index = new_index
|
||||
self._id_to_meta.update({})
|
||||
|
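The Faiss storage above reaches the shared index with inline `self._index.value if is_multiprocess else self._index` expressions, while the NanoVectorDB and NetworkX storages later in this diff wrap the same idea in small accessors (`_get_client()` / `_get_graph()`). As a sketch only (this helper is not part of the commit), the equivalent accessor for the Faiss class would look like:

    def _get_index(self):
        """Return the underlying Faiss index for the current mode (sketch, not in this commit)."""
        if is_multiprocess:
            # In multi-process mode the index lives in a Manager-backed Value proxy.
            return self._index.value
        return self._index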
@@ -1,7 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Any, Union, final
|
||||
import threading
|
||||
|
||||
from lightrag.base import (
|
||||
DocProcessingStatus,
|
||||
@@ -13,26 +12,7 @@ from lightrag.utils import (
|
||||
logger,
|
||||
write_json,
|
||||
)
|
||||
from lightrag.api.utils_api import manager as main_process_manager
|
||||
|
||||
# Global variables for shared memory management
|
||||
_init_lock = threading.Lock()
|
||||
_manager = None
|
||||
_shared_doc_status_data = None
|
||||
|
||||
|
||||
def _get_manager():
|
||||
"""Get or create the global manager instance"""
|
||||
global _manager, _shared_doc_status_data
|
||||
with _init_lock:
|
||||
if _manager is None:
|
||||
try:
|
||||
_manager = main_process_manager
|
||||
_shared_doc_status_data = _manager.dict()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared memory manager: {e}")
|
||||
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
||||
return _manager
|
||||
from .shared_storage import get_namespace_data, get_storage_lock
|
||||
|
||||
|
||||
@final
|
||||
@@ -43,45 +23,32 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
|
||||
# Get or create namespace data
|
||||
if self.namespace not in _shared_doc_status_data:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_doc_status_data:
|
||||
try:
|
||||
initial_data = load_json(self._file_name) or {}
|
||||
_shared_doc_status_data[self.namespace] = initial_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Shared data initialization failed: {e}")
|
||||
|
||||
try:
|
||||
self._data = _shared_doc_status_data[self.namespace]
|
||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
self._storage_lock = get_storage_lock()
|
||||
self._data = get_namespace_data(self.namespace)
|
||||
with self._storage_lock:
|
||||
self._data.update(load_json(self._file_name) or {})
|
||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||
return set(keys) - set(self._data.keys())
|
||||
with self._storage_lock:
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
result: list[dict[str, Any]] = []
|
||||
for id in ids:
|
||||
data = self._data.get(id, None)
|
||||
if data:
|
||||
result.append(data)
|
||||
with self._storage_lock:
|
||||
for id in ids:
|
||||
data = self._data.get(id, None)
|
||||
if data:
|
||||
result.append(data)
|
||||
return result
|
||||
|
||||
async def get_status_counts(self) -> dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
counts = {status.value: 0 for status in DocStatus}
|
||||
for doc in self._data.values():
|
||||
counts[doc["status"]] += 1
|
||||
with self._storage_lock:
|
||||
for doc in self._data.values():
|
||||
counts[doc["status"]] += 1
|
||||
return counts
|
||||
|
||||
async def get_docs_by_status(
|
||||
@@ -89,39 +56,46 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all documents with a specific status"""
|
||||
result = {}
|
||||
for k, v in self._data.items():
|
||||
if v["status"] == status.value:
|
||||
try:
|
||||
# Make a copy of the data to avoid modifying the original
|
||||
data = v.copy()
|
||||
# If content is missing, use content_summary as content
|
||||
if "content" not in data and "content_summary" in data:
|
||||
data["content"] = data["content_summary"]
|
||||
result[k] = DocProcessingStatus(**data)
|
||||
except KeyError as e:
|
||||
logger.error(f"Missing required field for document {k}: {e}")
|
||||
continue
|
||||
with self._storage_lock:
|
||||
for k, v in self._data.items():
|
||||
if v["status"] == status.value:
|
||||
try:
|
||||
# Make a copy of the data to avoid modifying the original
|
||||
data = v.copy()
|
||||
# If content is missing, use content_summary as content
|
||||
if "content" not in data and "content_summary" in data:
|
||||
data["content"] = data["content_summary"]
|
||||
result[k] = DocProcessingStatus(**data)
|
||||
except KeyError as e:
|
||||
logger.error(f"Missing required field for document {k}: {e}")
|
||||
continue
|
||||
return result
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
write_json(self._data, self._file_name)
|
||||
# File writes must hold the lock to prevent concurrent writes from multiple processes corrupting the file
|
||||
with self._storage_lock:
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
|
||||
self._data.update(data)
|
||||
with self._storage_lock:
|
||||
self._data.update(data)
|
||||
await self.index_done_callback()
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
return self._data.get(id)
|
||||
with self._storage_lock:
|
||||
return self._data.get(id)
|
||||
|
||||
async def delete(self, doc_ids: list[str]):
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
with self._storage_lock:
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await self.index_done_callback()
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
self._data.clear()
|
||||
with self._storage_lock:
|
||||
self._data.clear()
|
||||
|
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
import threading
|
||||
|
||||
from lightrag.base import (
|
||||
BaseKVStorage,
|
||||
@@ -12,26 +10,7 @@ from lightrag.utils import (
|
||||
logger,
|
||||
write_json,
|
||||
)
|
||||
from lightrag.api.utils_api import manager as main_process_manager
|
||||
|
||||
# Global variables for shared memory management
|
||||
_init_lock = threading.Lock()
|
||||
_manager = None
|
||||
_shared_kv_data = None
|
||||
|
||||
|
||||
def _get_manager():
|
||||
"""Get or create the global manager instance"""
|
||||
global _manager, _shared_kv_data
|
||||
with _init_lock:
|
||||
if _manager is None:
|
||||
try:
|
||||
_manager = main_process_manager
|
||||
_shared_kv_data = _manager.dict()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared memory manager: {e}")
|
||||
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
||||
return _manager
|
||||
from .shared_storage import get_namespace_data, get_storage_lock
|
||||
|
||||
|
||||
@final
|
||||
@@ -39,57 +18,49 @@ def _get_manager():
|
||||
class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
|
||||
# Get or create namespace data
|
||||
if self.namespace not in _shared_kv_data:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_kv_data:
|
||||
try:
|
||||
initial_data = load_json(self._file_name) or {}
|
||||
_shared_kv_data[self.namespace] = initial_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Shared data initialization failed: {e}")
|
||||
|
||||
try:
|
||||
self._data = _shared_kv_data[self.namespace]
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
self._storage_lock = get_storage_lock()
|
||||
self._data = get_namespace_data(self.namespace)
|
||||
with self._storage_lock:
|
||||
if not self._data:
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
write_json(self._data, self._file_name)
|
||||
# File writes must hold the lock to prevent concurrent writes from multiple processes corrupting the file
|
||||
with self._storage_lock:
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
return self._data.get(id)
|
||||
with self._storage_lock:
|
||||
return self._data.get(id)
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items()}
|
||||
if self._data.get(id, None)
|
||||
else None
|
||||
)
|
||||
for id in ids
|
||||
]
|
||||
with self._storage_lock:
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items()}
|
||||
if self._data.get(id, None)
|
||||
else None
|
||||
)
|
||||
for id in ids
|
||||
]
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
return set(keys) - set(self._data.keys())
|
||||
with self._storage_lock:
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
with self._storage_lock:
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
for doc_id in ids:
|
||||
self._data.pop(doc_id, None)
|
||||
with self._storage_lock:
|
||||
for doc_id in ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await self.index_done_callback()
|
||||
|
@@ -3,50 +3,29 @@ import os
|
||||
from typing import Any, final
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import threading
|
||||
import time
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
compute_mdhash_id,
|
||||
)
|
||||
from lightrag.api.utils_api import manager as main_process_manager
|
||||
import pipmaster as pm
|
||||
from lightrag.base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
|
||||
|
||||
if not pm.is_installed("nano-vectordb"):
|
||||
pm.install("nano-vectordb")
|
||||
|
||||
from nano_vectordb import NanoVectorDB
|
||||
|
||||
# Global variables for shared memory management
|
||||
_init_lock = threading.Lock()
|
||||
_manager = None
|
||||
_shared_vector_clients = None
|
||||
|
||||
|
||||
def _get_manager():
|
||||
"""Get or create the global manager instance"""
|
||||
global _manager, _shared_vector_clients
|
||||
with _init_lock:
|
||||
if _manager is None:
|
||||
try:
|
||||
_manager = main_process_manager
|
||||
_shared_vector_clients = _manager.dict()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared memory manager: {e}")
|
||||
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
||||
return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
def __post_init__(self):
|
||||
# Initialize lock only for file operations
|
||||
self._save_lock = asyncio.Lock()
|
||||
self._storage_lock = get_storage_lock()
|
||||
|
||||
# Use global config value if specified, otherwise use default
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
@@ -61,28 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
self._client = get_namespace_object(self.namespace)
|
||||
|
||||
# Get or create namespace client
|
||||
if self.namespace not in _shared_vector_clients:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_vector_clients:
|
||||
try:
|
||||
client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name
|
||||
)
|
||||
_shared_vector_clients[self.namespace] = client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Vector DB client initialization failed: {e}")
|
||||
with self._storage_lock:
|
||||
if is_multiprocess:
|
||||
if self._client.value is None:
|
||||
self._client.value = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
||||
)
|
||||
else:
|
||||
if self._client is None:
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
||||
)
|
||||
|
||||
try:
|
||||
self._client = _shared_vector_clients[self.namespace]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
logger.info(f"Initialized vector DB client for namespace {self.namespace}")
|
||||
|
||||
def _get_client(self):
|
||||
"""Get the appropriate client instance based on multiprocess mode"""
|
||||
if is_multiprocess:
|
||||
return self._client.value
|
||||
return self._client
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
@@ -104,6 +82,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
# Execute embedding outside of lock to avoid long lock times
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
@@ -111,7 +90,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
if len(embeddings) == len(list_data):
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
results = self._client.upsert(datas=list_data)
|
||||
with self._storage_lock:
|
||||
client = self._get_client()
|
||||
results = client.upsert(datas=list_data)
|
||||
return results
|
||||
else:
|
||||
# sometimes the embedding is not returned correctly. just log it.
|
||||
@@ -120,27 +101,32 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
# Execute embedding outside of lock to avoid long lock times
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
results = self._client.query(
|
||||
query=embedding,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
)
|
||||
results = [
|
||||
{
|
||||
**dp,
|
||||
"id": dp["__id__"],
|
||||
"distance": dp["__metrics__"],
|
||||
"created_at": dp.get("__created_at__"),
|
||||
}
|
||||
for dp in results
|
||||
]
|
||||
|
||||
with self._storage_lock:
|
||||
client = self._get_client()
|
||||
results = client.query(
|
||||
query=embedding,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
)
|
||||
results = [
|
||||
{
|
||||
**dp,
|
||||
"id": dp["__id__"],
|
||||
"distance": dp["__metrics__"],
|
||||
"created_at": dp.get("__created_at__"),
|
||||
}
|
||||
for dp in results
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def client_storage(self):
|
||||
return getattr(self._client, "_NanoVectorDB__storage")
|
||||
client = self._get_client()
|
||||
return getattr(client, "_NanoVectorDB__storage")
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
@@ -149,8 +135,10 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
try:
|
||||
self._client.delete(ids)
|
||||
logger.info(
|
||||
with self._storage_lock:
|
||||
client = self._get_client()
|
||||
client.delete(ids)
|
||||
logger.debug(
|
||||
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -162,35 +150,42 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
logger.debug(
|
||||
f"Attempting to delete entity {entity_name} with ID {entity_id}"
|
||||
)
|
||||
# Check if the entity exists
|
||||
if self._client.get([entity_id]):
|
||||
await self.delete([entity_id])
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
else:
|
||||
logger.debug(f"Entity {entity_name} not found in storage")
|
||||
|
||||
with self._storage_lock:
|
||||
client = self._get_client()
|
||||
# Check if the entity exists
|
||||
if client.get([entity_id]):
|
||||
client.delete([entity_id])
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
else:
|
||||
logger.debug(f"Entity {entity_name} not found in storage")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
try:
|
||||
relations = [
|
||||
dp
|
||||
for dp in self.client_storage["data"]
|
||||
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
|
||||
]
|
||||
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
|
||||
ids_to_delete = [relation["__id__"] for relation in relations]
|
||||
with self._storage_lock:
|
||||
client = self._get_client()
|
||||
storage = getattr(client, "_NanoVectorDB__storage")
|
||||
relations = [
|
||||
dp
|
||||
for dp in storage["data"]
|
||||
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
|
||||
]
|
||||
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
|
||||
ids_to_delete = [relation["__id__"] for relation in relations]
|
||||
|
||||
if ids_to_delete:
|
||||
await self.delete(ids_to_delete)
|
||||
logger.debug(
|
||||
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No relations found for entity {entity_name}")
|
||||
if ids_to_delete:
|
||||
client.delete(ids_to_delete)
|
||||
logger.debug(
|
||||
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No relations found for entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
async with self._save_lock:
|
||||
self._client.save()
|
||||
with self._storage_lock:
|
||||
client = self._get_client()
|
||||
client.save()
|
||||
|
@@ -1,18 +1,13 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
)
|
||||
from lightrag.api.utils_api import manager as main_process_manager
|
||||
from lightrag.utils import logger
|
||||
from lightrag.base import BaseGraphStorage
|
||||
from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
|
||||
|
||||
from lightrag.base import (
|
||||
BaseGraphStorage,
|
||||
)
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("networkx"):
|
||||
@@ -24,25 +19,6 @@ if not pm.is_installed("graspologic"):
|
||||
import networkx as nx
|
||||
from graspologic import embed
|
||||
|
||||
# Global variables for shared memory management
|
||||
_init_lock = threading.Lock()
|
||||
_manager = None
|
||||
_shared_graphs = None
|
||||
|
||||
|
||||
def _get_manager():
|
||||
"""Get or create the global manager instance"""
|
||||
global _manager, _shared_graphs
|
||||
with _init_lock:
|
||||
if _manager is None:
|
||||
try:
|
||||
_manager = main_process_manager
|
||||
_shared_graphs = _manager.dict()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared memory manager: {e}")
|
||||
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
||||
return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -97,76 +73,98 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
self._graphml_xml_file = os.path.join(
|
||||
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
|
||||
)
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
|
||||
# Get or create namespace graph
|
||||
if self.namespace not in _shared_graphs:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_graphs:
|
||||
try:
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
if preloaded_graph is not None:
|
||||
logger.info(
|
||||
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
_shared_graphs[self.namespace] = preloaded_graph or nx.Graph()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Graph initialization failed: {e}")
|
||||
|
||||
try:
|
||||
self._graph = _shared_graphs[self.namespace]
|
||||
self._node_embed_algorithms = {
|
||||
self._storage_lock = get_storage_lock()
|
||||
self._graph = get_namespace_object(self.namespace)
|
||||
with self._storage_lock:
|
||||
if is_multiprocess:
|
||||
if self._graph.value is None:
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
self._graph.value = preloaded_graph or nx.Graph()
|
||||
logger.info(
|
||||
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
else:
|
||||
if self._graph is None:
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
self._graph = preloaded_graph or nx.Graph()
|
||||
logger.info(
|
||||
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
}
|
||||
|
||||
def _get_graph(self):
|
||||
"""Get the appropriate graph instance based on multiprocess mode"""
|
||||
if is_multiprocess:
|
||||
return self._graph.value
|
||||
return self._graph
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
return self._graph.has_node(node_id)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
return graph.has_node(node_id)
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
return self._graph.has_edge(source_node_id, target_node_id)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
return graph.has_edge(source_node_id, target_node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
return self._graph.nodes.get(node_id)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
return graph.nodes.get(node_id)
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
return self._graph.degree(node_id)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
return graph.degree(node_id)
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
return graph.degree(src_id) + graph.degree(tgt_id)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
return self._graph.edges.get((source_node_id, target_node_id))
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
return graph.edges.get((source_node_id, target_node_id))
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
if self._graph.has_node(source_node_id):
|
||||
return list(self._graph.edges(source_node_id))
|
||||
return None
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
if graph.has_node(source_node_id):
|
||||
return list(graph.edges(source_node_id))
|
||||
return None
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
self._graph.add_node(node_id, **node_data)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
if self._graph.has_node(node_id):
|
||||
self._graph.remove_node(node_id)
|
||||
logger.info(f"Node {node_id} deleted from the graph.")
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
if graph.has_node(node_id):
|
||||
graph.remove_node(node_id)
|
||||
logger.debug(f"Node {node_id} deleted from the graph.")
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||
|
||||
async def embed_nodes(
|
||||
self, algorithm: str
|
||||
@@ -175,14 +173,15 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
# @TODO: NOT USED
|
||||
# TODO: NOT USED
|
||||
async def _node2vec_embed(self):
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
**self.global_config["node2vec_params"],
|
||||
)
|
||||
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
graph,
|
||||
**self.global_config["node2vec_params"],
|
||||
)
|
||||
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
def remove_nodes(self, nodes: list[str]):
|
||||
@@ -191,9 +190,11 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
Args:
|
||||
nodes: List of node IDs to be deleted
|
||||
"""
|
||||
for node in nodes:
|
||||
if self._graph.has_node(node):
|
||||
self._graph.remove_node(node)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
for node in nodes:
|
||||
if graph.has_node(node):
|
||||
graph.remove_node(node)
|
||||
|
||||
def remove_edges(self, edges: list[tuple[str, str]]):
|
||||
"""Delete multiple edges
|
||||
@@ -201,9 +202,11 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
Args:
|
||||
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
||||
"""
|
||||
for source, target in edges:
|
||||
if self._graph.has_edge(source, target):
|
||||
self._graph.remove_edge(source, target)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
for source, target in edges:
|
||||
if graph.has_edge(source, target):
|
||||
graph.remove_edge(source, target)
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
@@ -211,9 +214,11 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
[label1, label2, ...] # Alphabetically sorted label list
|
||||
"""
|
||||
labels = set()
|
||||
for node in self._graph.nodes():
|
||||
labels.add(str(node)) # Add node id as a label
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
labels = set()
|
||||
for node in graph.nodes():
|
||||
labels.add(str(node)) # Add node id as a label
|
||||
|
||||
# Return sorted list
|
||||
return sorted(list(labels))
|
||||
@@ -235,87 +240,86 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
||||
# Handle special case for "*" label
|
||||
if node_label == "*":
|
||||
# For "*", return the entire graph including all nodes and edges
|
||||
subgraph = (
|
||||
self._graph.copy()
|
||||
) # Create a copy to avoid modifying the original graph
|
||||
else:
|
||||
# Find nodes with matching node id (partial match)
|
||||
nodes_to_explore = []
|
||||
for n, attr in self._graph.nodes(data=True):
|
||||
if node_label in str(n): # Use partial matching
|
||||
nodes_to_explore.append(n)
|
||||
with self._storage_lock:
|
||||
graph = self._get_graph()
|
||||
|
||||
# Handle special case for "*" label
|
||||
if node_label == "*":
|
||||
# For "*", return the entire graph including all nodes and edges
|
||||
subgraph = graph.copy() # Create a copy to avoid modifying the original graph
|
||||
else:
|
||||
# Find nodes with matching node id (partial match)
|
||||
nodes_to_explore = []
|
||||
for n, attr in graph.nodes(data=True):
|
||||
if node_label in str(n): # Use partial matching
|
||||
nodes_to_explore.append(n)
|
||||
|
||||
if not nodes_to_explore:
|
||||
logger.warning(f"No nodes found with label {node_label}")
|
||||
return result
|
||||
if not nodes_to_explore:
|
||||
logger.warning(f"No nodes found with label {node_label}")
|
||||
return result
|
||||
|
||||
# Get subgraph using ego_graph
|
||||
subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
|
||||
# Get subgraph using ego_graph
|
||||
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
|
||||
|
||||
# Check if number of nodes exceeds max_graph_nodes
|
||||
max_graph_nodes = 500
|
||||
if len(subgraph.nodes()) > max_graph_nodes:
|
||||
origin_nodes = len(subgraph.nodes())
|
||||
node_degrees = dict(subgraph.degree())
|
||||
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
|
||||
:max_graph_nodes
|
||||
]
|
||||
top_node_ids = [node[0] for node in top_nodes]
|
||||
# Create new subgraph with only top nodes
|
||||
subgraph = subgraph.subgraph(top_node_ids)
|
||||
logger.info(
|
||||
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
|
||||
)
|
||||
|
||||
# Add nodes to result
|
||||
for node in subgraph.nodes():
|
||||
if str(node) in seen_nodes:
|
||||
continue
|
||||
|
||||
node_data = dict(subgraph.nodes[node])
|
||||
# Get entity_type as labels
|
||||
labels = []
|
||||
if "entity_type" in node_data:
|
||||
if isinstance(node_data["entity_type"], list):
|
||||
labels.extend(node_data["entity_type"])
|
||||
else:
|
||||
labels.append(node_data["entity_type"])
|
||||
|
||||
# Create node with properties
|
||||
node_properties = {k: v for k, v in node_data.items()}
|
||||
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=str(node), labels=[str(node)], properties=node_properties
|
||||
# Check if number of nodes exceeds max_graph_nodes
|
||||
max_graph_nodes = 500
|
||||
if len(subgraph.nodes()) > max_graph_nodes:
|
||||
origin_nodes = len(subgraph.nodes())
|
||||
node_degrees = dict(subgraph.degree())
|
||||
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
|
||||
:max_graph_nodes
|
||||
]
|
||||
top_node_ids = [node[0] for node in top_nodes]
|
||||
# Create new subgraph with only top nodes
|
||||
subgraph = subgraph.subgraph(top_node_ids)
|
||||
logger.info(
|
||||
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
|
||||
)
|
||||
)
|
||||
seen_nodes.add(str(node))
|
||||
|
||||
# Add edges to result
|
||||
for edge in subgraph.edges():
|
||||
source, target = edge
|
||||
edge_id = f"{source}-{target}"
|
||||
if edge_id in seen_edges:
|
||||
continue
|
||||
# Add nodes to result
|
||||
for node in subgraph.nodes():
|
||||
if str(node) in seen_nodes:
|
||||
continue
|
||||
|
||||
edge_data = dict(subgraph.edges[edge])
|
||||
node_data = dict(subgraph.nodes[node])
|
||||
# Get entity_type as labels
|
||||
labels = []
|
||||
if "entity_type" in node_data:
|
||||
if isinstance(node_data["entity_type"], list):
|
||||
labels.extend(node_data["entity_type"])
|
||||
else:
|
||||
labels.append(node_data["entity_type"])
|
||||
|
||||
# Create edge with complete information
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type="DIRECTED",
|
||||
source=str(source),
|
||||
target=str(target),
|
||||
properties=edge_data,
|
||||
# Create node with properties
|
||||
node_properties = {k: v for k, v in node_data.items()}
|
||||
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=str(node), labels=[str(node)], properties=node_properties
|
||||
)
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
seen_nodes.add(str(node))
|
||||
|
||||
# logger.info(result.edges)
|
||||
# Add edges to result
|
||||
for edge in subgraph.edges():
|
||||
source, target = edge
|
||||
edge_id = f"{source}-{target}"
|
||||
if edge_id in seen_edges:
|
||||
continue
|
||||
|
||||
edge_data = dict(subgraph.edges[edge])
|
||||
|
||||
# Create edge with complete information
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type="DIRECTED",
|
||||
source=str(source),
|
||||
target=str(target),
|
||||
properties=edge_data,
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
|
lightrag/kg/shared_storage.py (new file, 94 lines)
@@ -0,0 +1,94 @@
from multiprocessing.synchronize import Lock as ProcessLock
from threading import Lock as ThreadLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union

# Type alias covering both lock flavours
LockType = Union[ProcessLock, ThreadLock]

# Module-level shared state
_shared_data: Optional[Dict[str, Any]] = None
_namespace_objects: Optional[Dict[str, Any]] = None
_global_lock: Optional[LockType] = None
is_multiprocess = False
manager = None

def initialize_manager():
    """Initialize the manager; only needed in multi-process mode (workers > 1)"""
    global manager
    if manager is None:
        manager = Manager()

def _get_global_lock() -> LockType:
    global _global_lock, is_multiprocess

    if _global_lock is None:
        if is_multiprocess:
            _global_lock = manager.Lock()
        else:
            _global_lock = ThreadLock()

    return _global_lock

def get_storage_lock() -> LockType:
    """Return the storage lock used for data consistency"""
    return _get_global_lock()

def get_scan_lock() -> LockType:
    """Return the scan_progress lock used for data consistency"""
    return get_storage_lock()

def get_shared_data() -> Dict[str, Any]:
    """
    Return the shared data container used by all storage types.
    The multi-process shared dict is created lazily, only when needed, for better performance.
    """
    global _shared_data, is_multiprocess

    if _shared_data is None:
        lock = _get_global_lock()
        with lock:
            if _shared_data is None:
                if is_multiprocess:
                    _shared_data = manager.dict()
                else:
                    _shared_data = {}

    return _shared_data

def get_namespace_object(namespace: str) -> Any:
    """Get the shared object for a specific namespace"""
    global _namespace_objects, is_multiprocess

    if _namespace_objects is None:
        lock = _get_global_lock()
        with lock:
            if _namespace_objects is None:
                _namespace_objects = {}

    if namespace not in _namespace_objects:
        lock = _get_global_lock()
        with lock:
            if namespace not in _namespace_objects:
                if is_multiprocess:
                    _namespace_objects[namespace] = manager.Value('O', None)
                else:
                    _namespace_objects[namespace] = None

    return _namespace_objects[namespace]

def get_namespace_data(namespace: str) -> Dict[str, Any]:
    """Get the storage space for a specific storage type (namespace)"""
    shared_data = get_shared_data()
    lock = _get_global_lock()

    if namespace not in shared_data:
        with lock:
            if namespace not in shared_data:
                shared_data[namespace] = {}

    return shared_data[namespace]

def get_scan_progress() -> Dict[str, Any]:
    """Get the storage space for document scanning progress data"""
    return get_namespace_data('scan_progress')
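A short usage sketch of this module, mirroring how the storage classes in this commit consume it. Names follow the diff above; the class and its methods are illustrative, not part of the commit.

    from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock

    class ExampleKVStorage:
        def __init__(self, namespace: str):
            self.namespace = namespace
            # One lock guards every read/write; it is a Manager lock when
            # is_multiprocess is True, otherwise a plain threading.Lock.
            self._storage_lock = get_storage_lock()
            # The namespace dict is a Manager-backed proxy in multi-process mode,
            # a regular dict in single-process mode.
            self._data = get_namespace_data(self.namespace)

        def get(self, key):
            with self._storage_lock:
                return self._data.get(key)

        def upsert(self, items: dict):
            with self._storage_lock:
                self._data.update(items)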
@@ -266,13 +266,7 @@ class LightRAG:

    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)

    def __post_init__(self):
        # Initialize manager if needed
        from lightrag.api.utils_api import manager, initialize_manager
        if manager is None:
            initialize_manager()
            logger.info("Initialized manager for single process mode")

    def __post_init__(self):
        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
        set_logger(self.log_file_path, self.log_level)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")