diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index e4cb3e63..382060f7 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.2.4"
+__version__ = "1.2.5"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 5df4f765..fd09a691 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -50,9 +50,6 @@ from .auth import auth_handler
 # This update allows the user to put a different.env file for each lightrag folder
 load_dotenv(".env", override=True)
 
-# Read entity extraction cache config
-enable_llm_cache = os.getenv("ENABLE_LLM_CACHE_FOR_EXTRACT", "false").lower() == "true"
-
 # Initialize config parser
 config = configparser.ConfigParser()
 config.read("config.ini")
@@ -144,23 +141,25 @@ def create_app(args):
         try:
             # Initialize database connections
             await rag.initialize_storages()
-            await initialize_pipeline_status()
 
-            # Auto scan documents if enabled
-            if args.auto_scan_at_startup:
-                # Check if a task is already running (with lock protection)
-                pipeline_status = await get_namespace_data("pipeline_status")
-                should_start_task = False
-                async with get_pipeline_status_lock():
-                    if not pipeline_status.get("busy", False):
-                        should_start_task = True
-                # Only start the task if no other task is running
-                if should_start_task:
-                    # Create background task
-                    task = asyncio.create_task(run_scanning_process(rag, doc_manager))
-                    app.state.background_tasks.add(task)
-                    task.add_done_callback(app.state.background_tasks.discard)
-                    logger.info("Auto scan task started at startup.")
+            await initialize_pipeline_status()
+            pipeline_status = await get_namespace_data("pipeline_status")
+
+            should_start_autoscan = False
+            async with get_pipeline_status_lock():
+                # Auto scan documents if enabled
+                if args.auto_scan_at_startup:
+                    if not pipeline_status.get("autoscanned", False):
+                        pipeline_status["autoscanned"] = True
+                        should_start_autoscan = True
+
+            # Only run auto scan when no other process started it first
+            if should_start_autoscan:
+                # Create background task
+                task = asyncio.create_task(run_scanning_process(rag, doc_manager))
+                app.state.background_tasks.add(task)
+                task.add_done_callback(app.state.background_tasks.discard)
+                logger.info(f"Process {os.getpid()} auto scan task started at startup.")
 
             ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
@@ -326,7 +325,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=enable_llm_cache,  # Read from environment variable
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -355,7 +354,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=enable_llm_cache,  # Read from environment variable
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -419,6 +418,7 @@ def create_app(args):
             "doc_status_storage": args.doc_status_storage,
             "graph_storage": args.graph_storage,
             "vector_storage": args.vector_storage,
+            "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
         },
         "update_status": update_status,
     }
diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py
index 3e51fa4d..c1666192 100644
--- a/lightrag/api/routers/document_routes.py
+++ b/lightrag/api/routers/document_routes.py
@@ -16,7 +16,11 @@ from pydantic import BaseModel, Field, field_validator
 
 from lightrag import LightRAG
 from lightrag.base import DocProcessingStatus, DocStatus
-from ..utils_api import get_api_key_dependency, get_auth_dependency
+from lightrag.api.utils_api import (
+    get_api_key_dependency,
+    global_args,
+    get_auth_dependency,
+)
 
 router = APIRouter(
     prefix="/documents",
@@ -240,54 +244,93 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                 )
                 return False
             case ".pdf":
-                if not pm.is_installed("pypdf2"):  # type: ignore
-                    pm.install("pypdf2")
-                from PyPDF2 import PdfReader  # type: ignore
-                from io import BytesIO
+                if global_args["main_args"].document_loading_engine == "DOCLING":
+                    if not pm.is_installed("docling"):  # type: ignore
+                        pm.install("docling")
+                    from docling.document_converter import DocumentConverter
 
-                pdf_file = BytesIO(file)
-                reader = PdfReader(pdf_file)
-                for page in reader.pages:
-                    content += page.extract_text() + "\n"
+                    converter = DocumentConverter()
+                    result = converter.convert(file_path)
+                    content = result.document.export_to_markdown()
+                else:
+                    if not pm.is_installed("pypdf2"):  # type: ignore
+                        pm.install("pypdf2")
+                    from PyPDF2 import PdfReader  # type: ignore
+                    from io import BytesIO
+
+                    pdf_file = BytesIO(file)
+                    reader = PdfReader(pdf_file)
+                    for page in reader.pages:
+                        content += page.extract_text() + "\n"
             case ".docx":
-                if not pm.is_installed("python-docx"):  # type: ignore
-                    pm.install("docx")
-                from docx import Document  # type: ignore
-                from io import BytesIO
+                if global_args["main_args"].document_loading_engine == "DOCLING":
+                    if not pm.is_installed("docling"):  # type: ignore
+                        pm.install("docling")
+                    from docling.document_converter import DocumentConverter
 
-                docx_file = BytesIO(file)
-                doc = Document(docx_file)
-                content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
+                    converter = DocumentConverter()
+                    result = converter.convert(file_path)
+                    content = result.document.export_to_markdown()
+                else:
+                    if not pm.is_installed("python-docx"):  # type: ignore
+                        pm.install("docx")
+                    from docx import Document  # type: ignore
+                    from io import BytesIO
+
+                    docx_file = BytesIO(file)
+                    doc = Document(docx_file)
+                    content = "\n".join(
+                        [paragraph.text for paragraph in doc.paragraphs]
+                    )
             case ".pptx":
-                if not pm.is_installed("python-pptx"):  # type: ignore
-                    pm.install("pptx")
-                from pptx import Presentation  # type: ignore
-                from io import BytesIO
+                if global_args["main_args"].document_loading_engine == "DOCLING":
+                    if not pm.is_installed("docling"):  # type: ignore
+                        pm.install("docling")
+                    from docling.document_converter import DocumentConverter
 
-                pptx_file = BytesIO(file)
-                prs = Presentation(pptx_file)
-                for slide in prs.slides:
-                    for shape in slide.shapes:
-                        if hasattr(shape, "text"):
-                            content += shape.text + "\n"
+                    converter = DocumentConverter()
+                    result = converter.convert(file_path)
+                    content = result.document.export_to_markdown()
+                else:
+                    if not pm.is_installed("python-pptx"):  # type: ignore
+                        pm.install("pptx")
+                    from pptx import Presentation  # type: ignore
+                    from io import BytesIO
+
+                    pptx_file = BytesIO(file)
+                    prs = Presentation(pptx_file)
+                    for slide in prs.slides:
+                        for shape in slide.shapes:
+                            if hasattr(shape, "text"):
+                                content += shape.text + "\n"
             case ".xlsx":
-                if not pm.is_installed("openpyxl"):  # type: ignore
-                    pm.install("openpyxl")
-                from openpyxl import load_workbook  # type: ignore
-                from io import BytesIO
+                if global_args["main_args"].document_loading_engine == "DOCLING":
+                    if not pm.is_installed("docling"):  # type: ignore
+                        pm.install("docling")
+                    from docling.document_converter import DocumentConverter
 
-                xlsx_file = BytesIO(file)
-                wb = load_workbook(xlsx_file)
-                for sheet in wb:
-                    content += f"Sheet: {sheet.title}\n"
-                    for row in sheet.iter_rows(values_only=True):
-                        content += (
-                            "\t".join(
-                                str(cell) if cell is not None else "" for cell in row
+                    converter = DocumentConverter()
+                    result = converter.convert(file_path)
+                    content = result.document.export_to_markdown()
+                else:
+                    if not pm.is_installed("openpyxl"):  # type: ignore
+                        pm.install("openpyxl")
+                    from openpyxl import load_workbook  # type: ignore
+                    from io import BytesIO
+
+                    xlsx_file = BytesIO(file)
+                    wb = load_workbook(xlsx_file)
+                    for sheet in wb:
+                        content += f"Sheet: {sheet.title}\n"
+                        for row in sheet.iter_rows(values_only=True):
+                            content += (
+                                "\t".join(
+                                    str(cell) if cell is not None else ""
+                                    for cell in row
+                                )
+                                + "\n"
                             )
-                            + "\n"
-                        )
-                    content += "\n"
+                        content += "\n"
             case _:
                 logger.error(
                     f"Unsupported file type: {file_path.name} (extension {ext})"
diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py
index 9688d073..37d7354e 100644
--- a/lightrag/api/routers/ollama_api.py
+++ b/lightrag/api/routers/ollama_api.py
@@ -11,7 +11,7 @@ import asyncio
 from ascii_colors import trace_exception
 from lightrag import LightRAG, QueryParam
 from lightrag.utils import encode_string_by_tiktoken
-from ..utils_api import ollama_server_infos
+from lightrag.api.utils_api import ollama_server_infos
 
 
 # query mode according to query prefix (bypass is not LightRAG quer mode)
diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py
index dc467449..1f75db9c 100644
--- a/lightrag/api/utils_api.py
+++ b/lightrag/api/utils_api.py
@@ -18,6 +18,8 @@ from .auth import auth_handler
 # Load environment variables
 load_dotenv(override=True)
 
+global_args = {"main_args": None}
+
 
 class OllamaServerInfos:
     # Constants for emulated Ollama model information
@@ -360,8 +362,17 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
 
+    # Inject LLM cache configuration
+    args.enable_llm_cache_for_extract = get_env_value(
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
+    )
+
+    # Select Document loading tool (DOCLING, DEFAULT)
+    args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
+
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
 
+    global_args["main_args"] = args
     return args
 
@@ -451,8 +462,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.history_turns}")
     ASCIIColors.white("    ├─ Cosine Threshold: ", end="")
     ASCIIColors.yellow(f"{args.cosine_threshold}")
-    ASCIIColors.white("    └─ Top-K: ", end="")
+    ASCIIColors.white("    ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
+    ASCIIColors.white("    └─ LLM Cache for Extraction Enabled: ", end="")
+    ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")
 
     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")
diff --git a/lightrag/base.py b/lightrag/base.py
index c84c7c62..86566787 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -127,6 +127,30 @@ class BaseVectorStorage(StorageNameSpace, ABC):
     async def delete_entity_relation(self, entity_name: str) -> None:
         """Delete relations for a given entity."""
 
+    @abstractmethod
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        pass
+
+    @abstractmethod
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        pass
+
 
 @dataclass
 class BaseKVStorage(StorageNameSpace, ABC):
diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py
index 35b4cb58..84d43326 100644
--- a/lightrag/kg/chroma_impl.py
+++ b/lightrag/kg/chroma_impl.py
@@ -271,3 +271,67 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error during prefix search in ChromaDB: {str(e)}")
             raise
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Query the collection for a single vector by ID
+            result = self._collection.get(
+                ids=[id], include=["metadatas", "embeddings", "documents"]
+            )
+
+            if not result or not result["ids"] or len(result["ids"]) == 0:
+                return None
+
+            # Format the result to match the expected structure
+            return {
+                "id": result["ids"][0],
+                "vector": result["embeddings"][0],
+                "content": result["documents"][0],
+                **result["metadatas"][0],
+            }
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Query the collection for multiple vectors by IDs
+            result = self._collection.get(
+                ids=ids, include=["metadatas", "embeddings", "documents"]
+            )
+
+            if not result or not result["ids"] or len(result["ids"]) == 0:
+                return []
+
+            # Format the results to match the expected structure
+            return [
+                {
+                    "id": result["ids"][i],
+                    "vector": result["embeddings"][i],
+                    "content": result["documents"][i],
+                    **result["metadatas"][i],
+                }
+                for i in range(len(result["ids"]))
+            ]
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index 6832b756..57b0cae0 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -394,3 +394,46 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
         logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
         return matching_records
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        # Find the Faiss internal ID for the custom ID
+        fid = self._find_faiss_id_by_custom_id(id)
+        if fid is None:
+            return None
+
+        # Get the metadata for the found ID
+        metadata = self._id_to_meta.get(fid, {})
+        if not metadata:
+            return None
+
+        return {**metadata, "id": metadata.get("__id__")}
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        results = []
+        for id in ids:
+            fid = self._find_faiss_id_by_custom_id(id)
+            if fid is not None:
+                metadata = self._id_to_meta.get(fid, {})
+                if metadata:
+                    results.append({**metadata, "id": metadata.get("__id__")})
+
+        return results
diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 01c657fa..57a34ae5 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -15,6 +15,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
 
@@ -27,21 +31,25 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None
 
     async def initialize(self):
         """Initialize storage data"""
-        # check need_init must before get_namespace_data
-        need_init = try_initialize_namespace(self.namespace)
-        self._data = await get_namespace_data(self.namespace)
-        if need_init:
-            loaded_data = load_json(self._file_name) or {}
-            async with self._storage_lock:
-                self._data.update(loaded_data)
-                logger.info(
-                    f"Loaded document status storage with {len(loaded_data)} records"
-                )
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
+            # check need_init must before get_namespace_data
+            need_init = await try_initialize_namespace(self.namespace)
+            self._data = await get_namespace_data(self.namespace)
+            if need_init:
+                loaded_data = load_json(self._file_name) or {}
+                async with self._storage_lock:
+                    self._data.update(loaded_data)
+                    logger.info(
+                        f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
+                    )
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
@@ -87,18 +95,24 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            data_dict = (
-                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
-            )
-            write_json(data_dict, self._file_name)
+            if self.storage_updated.value:
+                data_dict = (
+                    dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+                )
+                logger.info(
+                    f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}"
+                )
+                write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
+            await set_all_update_flags(self.namespace)
+
         await self.index_done_callback()
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
@@ -109,9 +123,12 @@ class JsonDocStatusStorage(DocStatusStorage):
         async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()
 
     async def drop(self) -> None:
         """Drop the storage"""
         async with self._storage_lock:
             self._data.clear()
+            await set_all_update_flags(self.namespace)
+
+        await self.index_done_callback()
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index c0b61a63..e7deaf15 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -13,6 +13,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
 
@@ -23,26 +27,63 @@ class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None
 
     async def initialize(self):
         """Initialize storage data"""
-        # check need_init must before get_namespace_data
-        need_init = try_initialize_namespace(self.namespace)
-        self._data = await get_namespace_data(self.namespace)
-        if need_init:
-            loaded_data = load_json(self._file_name) or {}
-            async with self._storage_lock:
-                self._data.update(loaded_data)
-                logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
+            # check need_init must before get_namespace_data
+            need_init = await try_initialize_namespace(self.namespace)
+            self._data = await get_namespace_data(self.namespace)
+            if need_init:
+                loaded_data = load_json(self._file_name) or {}
+                async with self._storage_lock:
+                    self._data.update(loaded_data)
+
+                    # Calculate data count based on namespace
+                    if self.namespace.endswith("cache"):
+                        # For cache namespaces, sum the cache entries across all cache types
+                        data_count = sum(
+                            len(first_level_dict)
+                            for first_level_dict in loaded_data.values()
+                            if isinstance(first_level_dict, dict)
+                        )
+                    else:
+                        # For non-cache namespaces, use the original count method
+                        data_count = len(loaded_data)
+
+                    logger.info(
+                        f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
+                    )
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            data_dict = (
-                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
-            )
-            write_json(data_dict, self._file_name)
+            if self.storage_updated.value:
+                data_dict = (
+                    dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+                )
+
+                # Calculate data count based on namespace
+                if self.namespace.endswith("cache"):
+                    # For cache namespaces, sum the cache entries across all cache types
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in data_dict.values()
+                        if isinstance(first_level_dict, dict)
+                    )
+                else:
+                    # For non-cache namespaces, use the original count method
+                    data_count = len(data_dict)
+
+                logger.info(
+                    f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}"
+                )
+                write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)
 
     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage
@@ -73,15 +114,16 @@ class JsonKVStorage(BaseKVStorage):
         return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
-            left_data = {k: v for k, v in data.items() if k not in self._data}
-            self._data.update(left_data)
+            self._data.update(data)
+            await set_all_update_flags(self.namespace)
 
     async def delete(self, ids: list[str]) -> None:
         async with self._storage_lock:
             for doc_id in ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index 8b82ddf1..4b4577ca 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -233,3 +233,57 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error searching for records with prefix '{prefix}': {e}")
             return []
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Query Milvus for a specific ID
+            result = self._client.query(
+                collection_name=self.namespace,
+                filter=f'id == "{id}"',
+                output_fields=list(self.meta_fields) + ["id"],
+            )
+
+            if not result or len(result) == 0:
+                return None
+
+            return result[0]
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Prepare the ID filter expression
+            id_list = '", "'.join(ids)
+            filter_expr = f'id in ["{id_list}"]'
+
+            # Query Milvus with the filter
+            result = self._client.query(
+                collection_name=self.namespace,
+                filter=filter_expr,
+                output_fields=list(self.meta_fields) + ["id"],
+            )
+
+            return result or []
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index a2d9e51f..7d43e4f4 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -1073,6 +1073,59 @@ class MongoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
             return []
 
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Search for the specific ID in MongoDB
+            result = await self._data.find_one({"_id": id})
+            if result:
+                # Format the result to include id field expected by API
+                result_dict = dict(result)
+                if "_id" in result_dict and "id" not in result_dict:
+                    result_dict["id"] = result_dict["_id"]
+                return result_dict
+            return None
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Query MongoDB for multiple IDs
+            cursor = self._data.find({"_id": {"$in": ids}})
+            results = await cursor.to_list(length=None)
+
+            # Format results to include id field expected by API
+            formatted_results = []
+            for result in results:
+                result_dict = dict(result)
+                if "_id" in result_dict and "id" not in result_dict:
+                    result_dict["id"] = result_dict["_id"]
+                formatted_results.append(result_dict)
+
+            return formatted_results
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []
+
 
 async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
     collection_names = await db.list_collection_names()
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index c97aaa3a..4f739091 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -258,3 +258,33 @@ class NanoVectorDBStorage(BaseVectorStorage):
 
         logger.debug(f"Found {len(matching_records)} records with prefix '{prefix}'")
         return matching_records
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        client = await self._get_client()
+        result = client.get([id])
+        if result:
+            return result[0]
+        return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        client = await self._get_client()
+        return client.get(ids)
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index fec39138..d0c6c779 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -3,7 +3,7 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, List, Dict, final
+from typing import Any, final, Optional
 import numpy as np
 import configparser
 
@@ -15,6 +15,7 @@ from tenacity import (
     retry_if_exception_type,
 )
 
+import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@@ -37,6 +38,9 @@ config.read("config.ini", "utf-8")
 # Get maximum number of graph nodes from environment variable, default is 1000
 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
 
+# Set neo4j logger level to ERROR to suppress warning logs
+logging.getLogger("neo4j").setLevel(logging.ERROR)
+
 
 @final
 @dataclass
@@ -60,19 +64,25 @@ class Neo4JStorage(BaseGraphStorage):
         MAX_CONNECTION_POOL_SIZE = int(
             os.environ.get(
                 "NEO4J_MAX_CONNECTION_POOL_SIZE",
-                config.get("neo4j", "connection_pool_size", fallback=800),
+                config.get("neo4j", "connection_pool_size", fallback=50),
             )
         )
         CONNECTION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_TIMEOUT",
-                config.get("neo4j", "connection_timeout", fallback=60.0),
+                config.get("neo4j", "connection_timeout", fallback=30.0),
             ),
         )
         CONNECTION_ACQUISITION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
-                config.get("neo4j", "connection_acquisition_timeout", fallback=60.0),
+                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
+            ),
+        )
+        MAX_TRANSACTION_RETRY_TIME = float(
+            os.environ.get(
+                "NEO4J_MAX_TRANSACTION_RETRY_TIME",
+                config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
             ),
         )
         DATABASE = os.environ.get(
@@ -85,6 +95,7 @@ class Neo4JStorage(BaseGraphStorage):
             max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
             connection_timeout=CONNECTION_TIMEOUT,
             connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
+            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
         )
 
         # Try to connect to the database
@@ -152,65 +163,84 @@ class Neo4JStorage(BaseGraphStorage):
         }
 
     async def close(self):
+        """Close the Neo4j driver and release all resources"""
         if self._driver:
             await self._driver.close()
             self._driver = None
 
     async def __aexit__(self, exc_type, exc, tb):
-        if self._driver:
-            await self._driver.close()
+        """Ensure driver is closed when context manager exits"""
+        await self.close()
 
     async def index_done_callback(self) -> None:
         # Noe4J handles persistence automatically
         pass
 
-    async def _label_exists(self, label: str) -> bool:
-        """Check if a label exists in the Neo4j database."""
-        query = "CALL db.labels() YIELD label RETURN label"
-        try:
-            async with self._driver.session(database=self._DATABASE) as session:
-                result = await session.run(query)
-                labels = [record["label"] for record in await result.data()]
-                return label in labels
-        except Exception as e:
-            logger.error(f"Error checking label existence: {e}")
-            return False
-
-    async def _ensure_label(self, label: str) -> str:
-        """Ensure a label exists by validating it."""
-        clean_label = label.strip('"')
-        if not await self._label_exists(clean_label):
-            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
-        return clean_label
-
     async def has_node(self, node_id: str) -> bool:
-        entity_name_label = await self._ensure_label(node_id)
-        async with self._driver.session(database=self._DATABASE) as session:
-            query = (
-                f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
-            )
-            result = await session.run(query)
-            single_result = await result.single()
-            logger.debug(
-                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
-            )
-            return single_result["node_exists"]
+        """
+        Check if a node with the given label exists in the database
+
+        Args:
+            node_id: Label of the node to check
+
+        Returns:
+            bool: True if node exists, False otherwise
+
+        Raises:
+            ValueError: If node_id is invalid
+            Exception: If there is an error executing the query
+        """
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            try:
+                query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
+                result = await session.run(query, entity_id=node_id)
+                single_result = await result.single()
+                await result.consume()  # Ensure result is fully consumed
+                return single_result["node_exists"]
+            except Exception as e:
+                logger.error(f"Error checking node existence for {node_id}: {str(e)}")
+                await result.consume()  # Ensure results are consumed even on error
+                raise
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        entity_name_label_source = source_node_id.strip('"')
-        entity_name_label_target = target_node_id.strip('"')
+        """
+        Check if an edge exists between two nodes
 
-        async with self._driver.session(database=self._DATABASE) as session:
-            query = (
-                f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
-                "RETURN COUNT(r) > 0 AS edgeExists"
-            )
-            result = await session.run(query)
-            single_result = await result.single()
-            logger.debug(
-                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
-            )
-            return single_result["edgeExists"]
+        Args:
+            source_node_id: Label of the source node
+            target_node_id: Label of the target node
+
+        Returns:
+            bool: True if edge exists, False otherwise
+
+        Raises:
+            ValueError: If either node_id is invalid
+            Exception: If there is an error executing the query
+        """
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            try:
+                query = (
+                    "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
+                    "RETURN COUNT(r) > 0 AS edgeExists"
+                )
+                result = await session.run(
+                    query,
+                    source_entity_id=source_node_id,
+                    target_entity_id=target_node_id,
+                )
+                single_result = await result.single()
+                await result.consume()  # Ensure result is fully consumed
+                return single_result["edgeExists"]
+            except Exception as e:
+                logger.error(
+                    f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
+                )
+                await result.consume()  # Ensure results are consumed even on error
+                raise
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier.
@@ -221,161 +251,258 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             dict: Node properties if found
             None: If node not found
+
+        Raises:
+            ValueError: If node_id is invalid
+            Exception: If there is an error executing the query
         """
-        async with self._driver.session(database=self._DATABASE) as session:
-            entity_name_label = await self._ensure_label(node_id)
-            query = f"MATCH (n:`{entity_name_label}`) RETURN n"
-            result = await session.run(query)
-            record = await result.single()
-            if record:
-                node = record["n"]
-                node_dict = dict(node)
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
-                )
-                return node_dict
-            return None
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            try:
+                query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
+                result = await session.run(query, entity_id=node_id)
+                try:
+                    records = await result.fetch(
+                        2
+                    )  # Get 2 records for duplication check
+
+                    if len(records) > 1:
+                        logger.warning(
+                            f"Multiple nodes found with label '{node_id}'. Using first node."
+                        )
+                    if records:
+                        node = records[0]["n"]
+                        node_dict = dict(node)
+                        # Remove base label from labels list if it exists
+                        if "labels" in node_dict:
+                            node_dict["labels"] = [
+                                label
+                                for label in node_dict["labels"]
+                                if label != "base"
+                            ]
+                        logger.debug(f"Neo4j query node {query} return: {node_dict}")
+                        return node_dict
+                    return None
+                finally:
+                    await result.consume()  # Ensure result is fully consumed
+            except Exception as e:
+                logger.error(f"Error getting node for {node_id}: {str(e)}")
+                raise
 
     async def node_degree(self, node_id: str) -> int:
-        entity_name_label = node_id.strip('"')
+        """Get the degree (number of relationships) of a node with the given label.
+        If multiple nodes have the same label, returns the degree of the first node.
+        If no node is found, returns 0.
-        async with self._driver.session(database=self._DATABASE) as session:
-            query = f"""
-                MATCH (n:`{entity_name_label}`)
-                RETURN COUNT{{ (n)--() }} AS totalEdgeCount
-            """
-            result = await session.run(query)
-            record = await result.single()
-            if record:
-                edge_count = record["totalEdgeCount"]
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
-                )
-                return edge_count
-            else:
-                return None
+
+        Args:
+            node_id: The label of the node
+
+        Returns:
+            int: The number of relationships the node has, or 0 if no node found
+
+        Raises:
+            ValueError: If node_id is invalid
+            Exception: If there is an error executing the query
+        """
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            try:
+                query = """
+                    MATCH (n:base {entity_id: $entity_id})
+                    OPTIONAL MATCH (n)-[r]-()
+                    RETURN COUNT(r) AS degree
+                """
+                result = await session.run(query, entity_id=node_id)
+                try:
+                    record = await result.single()
+
+                    if not record:
+                        logger.warning(f"No node found with label '{node_id}'")
+                        return 0
+
+                    degree = record["degree"]
+                    logger.debug(
+                        f"Neo4j query node degree for {node_id} return: {degree}"
+                    )
+                    return degree
+                finally:
+                    await result.consume()  # Ensure result is fully consumed
+            except Exception as e:
+                logger.error(f"Error getting node degree for {node_id}: {str(e)}")
+                raise
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        entity_name_label_source = src_id.strip('"')
-        entity_name_label_target = tgt_id.strip('"')
-        src_degree = await self.node_degree(entity_name_label_source)
-        trg_degree = await self.node_degree(entity_name_label_target)
+        """Get the total degree (sum of relationships) of two nodes.
+
+        Args:
+            src_id: Label of the source node
+            tgt_id: Label of the target node
+
+        Returns:
+            int: Sum of the degrees of both nodes
+        """
+        src_degree = await self.node_degree(src_id)
+        trg_degree = await self.node_degree(tgt_id)
 
         # Convert None to 0 for addition
         src_degree = 0 if src_degree is None else src_degree
         trg_degree = 0 if trg_degree is None else trg_degree
 
         degrees = int(src_degree) + int(trg_degree)
-        logger.debug(
-            f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
-        )
         return degrees
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
+        """Get edge properties between two nodes.
+
+        Args:
+            source_node_id: Label of the source node
+            target_node_id: Label of the target node
+
+        Returns:
+            dict: Edge properties if found, default properties if not found or on error
+
+        Raises:
+            ValueError: If either node_id is invalid
+            Exception: If there is an error executing the query
+        """
         try:
-            entity_name_label_source = source_node_id.strip('"')
-            entity_name_label_target = target_node_id.strip('"')
-
-            async with self._driver.session(database=self._DATABASE) as session:
-                query = f"""
-                MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                query = """
+                MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
                 RETURN properties(r) as edge_properties
-                LIMIT 1
                 """
-
-                result = await session.run(query)
-                record = await result.single()
-                if record:
-                    try:
-                        result = dict(record["edge_properties"])
-                        logger.info(f"Result: {result}")
-                        # Ensure required keys exist with defaults
-                        required_keys = {
-                            "weight": 0.0,
-                            "source_id": None,
-                            "description": None,
-                            "keywords": None,
-                        }
-                        for key, default_value in required_keys.items():
-                            if key not in result:
-                                result[key] = default_value
-                                logger.warning(
-                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
-                                    f"missing {key}, using default: {default_value}"
-                                )
-
-                        logger.debug(
-                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
-                        )
-                        return result
-                    except (KeyError, TypeError, ValueError) as e:
-                        logger.error(
-                            f"Error processing edge properties between {entity_name_label_source} "
-                            f"and {entity_name_label_target}: {str(e)}"
-                        )
-                        # Return default edge properties on error
-                        return {
-                            "weight": 0.0,
-                            "description": None,
-                            "keywords": None,
-                            "source_id": None,
-                        }
-
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+                result = await session.run(
+                    query,
+                    source_entity_id=source_node_id,
+                    target_entity_id=target_node_id,
                 )
-                # Return default edge properties when no edge found
-                return {
-                    "weight": 0.0,
-                    "description": None,
-                    "keywords": None,
-                    "source_id": None,
-                }
+                try:
+                    records = await result.fetch(2)
+
+                    if len(records) > 1:
+                        logger.warning(
+                            f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
+                        )
+                    if records:
+                        try:
+                            edge_result = dict(records[0]["edge_properties"])
+                            logger.debug(f"Result: {edge_result}")
+                            # Ensure required keys exist with defaults
+                            required_keys = {
+                                "weight": 0.0,
+                                "source_id": None,
+                                "description": None,
+                                "keywords": None,
+                            }
+                            for key, default_value in required_keys.items():
+                                if key not in edge_result:
+                                    edge_result[key] = default_value
+                                    logger.warning(
+                                        f"Edge between {source_node_id} and {target_node_id} "
+                                        f"missing {key}, using default: {default_value}"
+                                    )
+
+                            logger.debug(
+                                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
+                            )
+                            return edge_result
+                        except (KeyError, TypeError, ValueError) as e:
+                            logger.error(
+                                f"Error processing edge properties between {source_node_id} "
+                                f"and {target_node_id}: {str(e)}"
+                            )
+                            # Return default edge properties on error
+                            return {
+                                "weight": 0.0,
+                                "source_id": None,
+                                "description": None,
+                                "keywords": None,
+                            }
+
+                    logger.debug(
+                        f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
+                    )
+                    # Return default edge properties when no edge found
+                    return {
+                        "weight": 0.0,
+                        "source_id": None,
+                        "description": None,
+                        "keywords": None,
+                    }
+                finally:
+                    await result.consume()  # Ensure result is fully consumed
         except Exception as e:
             logger.error(
                 f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
             )
-            # Return default edge properties on error
-            return {
-                "weight": 0.0,
-                "description": None,
-                "keywords": None,
-                "source_id": None,
-            }
+            raise
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        node_label = source_node_id.strip('"')
+        """Retrieves all edges (relationships) for a particular node identified by its label.
 
+        Args:
+            source_node_id: Label of the node to get edges for
+
+        Returns:
+            list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
+            None: If no edges found
+
+        Raises:
+            ValueError: If source_node_id is invalid
+            Exception: If there is an error executing the query
         """
-        Retrieves all edges (relationships) for a particular node identified by its label.
-        :return: List of dictionaries containing edge information
-        """
-        query = f"""MATCH (n:`{node_label}`)
-                OPTIONAL MATCH (n)-[r]-(connected)
-                RETURN n, r, connected"""
-        async with self._driver.session(database=self._DATABASE) as session:
-            results = await session.run(query)
-            edges = []
-            async for record in results:
-                source_node = record["n"]
-                connected_node = record["connected"]
+        try:
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                try:
+                    query = """MATCH (n:base {entity_id: $entity_id})
+                            OPTIONAL MATCH (n)-[r]-(connected:base)
+                            WHERE connected.entity_id IS NOT NULL
+                            RETURN n, r, connected"""
+                    results = await session.run(query, entity_id=source_node_id)
 
-                source_label = (
-                    list(source_node.labels)[0] if source_node.labels else None
-                )
-                target_label = (
-                    list(connected_node.labels)[0]
-                    if connected_node and connected_node.labels
-                    else None
-                )
+                    edges = []
+                    async for record in results:
+                        source_node = record["n"]
+                        connected_node = record["connected"]
 
-                if source_label and target_label:
-                    edges.append((source_label, target_label))
+                        # Skip if either node is None
+                        if not source_node or not connected_node:
+                            continue
 
-            return edges
+                        source_label = (
+                            source_node.get("entity_id")
+                            if source_node.get("entity_id")
+                            else None
+                        )
+                        target_label = (
+                            connected_node.get("entity_id")
+                            if connected_node.get("entity_id")
+                            else None
+                        )
+
+                        if source_label and target_label:
+                            edges.append((source_label, target_label))
+
+                    await results.consume()  # Ensure results are consumed
+                    return edges
+                except Exception as e:
+                    logger.error(
+                        f"Error getting edges for node {source_node_id}: {str(e)}"
+                    )
+                    await results.consume()  # Ensure results are consumed even on error
+                    raise
+        except Exception as e:
+            logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
+            raise
 
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -397,26 +524,47 @@ class Neo4JStorage(BaseGraphStorage):
             node_id: The unique identifier for the node (used as label)
             node_data: Dictionary of node properties
         """
-        label = await self._ensure_label(node_id)
         properties = node_data
-
-        async def _do_upsert(tx: AsyncManagedTransaction):
-            query = f"""
-            MERGE (n:`{label}`)
-            SET n += $properties
-            """
-            await tx.run(query, properties=properties)
-            logger.debug(
-                f"Upserted node with label '{label}' and properties: {properties}"
-            )
+        if "entity_id" not in properties:
+            raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
+        entity_type = properties["entity_type"]
+        entity_id = properties["entity_id"]
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
-                await session.execute_write(_do_upsert)
+
+                async def execute_upsert(tx: AsyncManagedTransaction):
+                    query = (
+                        """
+                    MERGE (n:base {entity_id: $properties.entity_id})
+                    SET n += $properties
+                    SET n:`%s`
+                    """
+                        % entity_type
+                    )
+                    result = await tx.run(query, properties=properties)
+                    logger.debug(
+                        f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
+                    )
+                    await result.consume()  # Ensure result is fully consumed
+
+                await session.execute_write(execute_upsert)
         except Exception as e:
             logger.error(f"Error during upsert: {str(e)}")
             raise
 
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
+            )
+        ),
+    )
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -434,34 +582,47 @@ class Neo4JStorage(BaseGraphStorage):
     ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.
+        Ensures both source and target nodes exist and are unique before creating the edge.
+        Uses entity_id property to uniquely identify nodes.
 
         Args:
             source_node_id (str): Label of the source node (used as identifier)
             target_node_id (str): Label of the target node (used as identifier)
             edge_data (dict): Dictionary of properties to set on the edge
+
+        Raises:
+            ValueError: If either source or target node does not exist or is not unique
         """
-        source_label = await self._ensure_label(source_node_id)
-        target_label = await self._ensure_label(target_node_id)
-        edge_properties = edge_data
-
-        async def _do_upsert_edge(tx: AsyncManagedTransaction):
-            query = f"""
-            MATCH (source:`{source_label}`)
-            WITH source
-            MATCH (target:`{target_label}`)
-            MERGE (source)-[r:DIRECTED]->(target)
-            SET r += $properties
-            RETURN r
-            """
-            result = await tx.run(query, properties=edge_properties)
-            record = await result.single()
-            logger.debug(
-                f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
-            )
-
         try:
+            edge_properties = edge_data
             async with self._driver.session(database=self._DATABASE) as session:
-                await session.execute_write(_do_upsert_edge)
+
+                async def execute_upsert(tx: AsyncManagedTransaction):
+                    query = """
+                    MATCH (source:base {entity_id: $source_entity_id})
+                    WITH source
+                    MATCH (target:base {entity_id: $target_entity_id})
+                    MERGE (source)-[r:DIRECTED]-(target)
+                    SET r += $properties
+                    RETURN r, source, target
+                    """
+                    result = await tx.run(
+                        query,
+                        source_entity_id=source_node_id,
+                        target_entity_id=target_node_id,
+                        properties=edge_properties,
+                    )
+                    try:
+                        records = await result.fetch(2)
+                        if records:
+                            logger.debug(
+                                f"Upserted edge from '{source_node_id}' to '{target_node_id}' "
+                                f"with properties: {edge_properties}"
+                            )
+                    finally:
+                        await result.consume()  # Ensure result is consumed
+
+                await session.execute_write(execute_upsert)
         except Exception as e:
             logger.error(f"Error during edge upsert: {str(e)}")
             raise
@@ -470,199 +631,293 @@ class Neo4JStorage(BaseGraphStorage):
         print("Implemented but never called.")
 
     async def get_knowledge_graph(
-        self, node_label: str, max_depth: int = 5
+        self,
+        node_label: str,
+        max_depth: int = 3,
+        min_degree: int = 0,
+        inclusive: bool = False,
     ) -> KnowledgeGraph:
         """
         Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
         Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
         When reducing the number of nodes, the prioritization criteria are as follows:
-            1. Label matching nodes take precedence (nodes containing the specified label string)
-            2. Followed by nodes directly connected to the matching nodes
-            3. Finally, the degree of the nodes
+            1. min_degree does not affect nodes directly connected to the matching nodes
+            2. Label matching nodes take precedence
+            3. Followed by nodes directly connected to the matching nodes
+            4. Finally, the degree of the nodes
 
         Args:
-            node_label (str): String to match in node labels (will match any node containing this string in its label)
-            max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
+            node_label: Label of the starting node
+            max_depth: Maximum depth of the subgraph
+            min_degree: Minimum degree of nodes to include. Defaults to 0
+            inclusive: Do an inclusive search if true
 
         Returns:
             KnowledgeGraph: Complete connected subgraph for specified node
         """
-        label = node_label.strip('"')
-        # Escape single quotes to prevent injection attacks
-        escaped_label = label.replace("'", "\\'")
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()
 
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             try:
-                if label == "*":
+                if node_label == "*":
                     main_query = """
                     MATCH (n)
                     OPTIONAL MATCH (n)-[r]-()
                     WITH n, count(r) AS degree
+                    WHERE degree >= $min_degree
                     ORDER BY degree DESC
                     LIMIT $max_nodes
-                    WITH collect(n) AS nodes
-                    MATCH (a)-[r]->(b)
-                    WHERE a IN nodes AND b IN nodes
-                    RETURN nodes, collect(DISTINCT r) AS relationships
+                    WITH collect({node: n}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                     )
                 else:
-                    validate_query = f"""
-                    MATCH (n)
-                    WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
-                    RETURN n LIMIT 1
-                    """
-                    validate_result = await session.run(validate_query)
-                    if not await validate_result.single():
-                        logger.warning(
-                            f"No nodes containing '{label}' in their labels found!"
-                        )
-                        return result
-
                     # Main query uses partial matching
-                    main_query = f"""
+                    main_query = """
                     MATCH (start)
-                    WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
+                    WHERE
+                        CASE
+                            WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
+                            ELSE start.entity_id = $entity_id
+                        END
                     WITH start
-                    CALL apoc.path.subgraphAll(start, {{
-                        relationshipFilter: '>',
+                    CALL apoc.path.subgraphAll(start, {
+                        relationshipFilter: '',
                         minLevel: 0,
-                        maxLevel: {max_depth},
+                        maxLevel: $max_depth,
                         bfs: true
-                    }})
+                    })
                     YIELD nodes, relationships
                     WITH start, nodes, relationships
                     UNWIND nodes AS node
                     OPTIONAL MATCH (node)-[r]-()
-                    WITH node, count(r) AS degree, start, nodes, relationships,
-                         CASE
-                             WHEN id(node) = id(start) THEN 2
-                             WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
-                             ELSE 0
-                         END AS priority
-                    ORDER BY priority DESC, degree DESC
+                    WITH node, count(r) AS degree, start, nodes, relationships
+                    WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
+                    ORDER BY
+                        CASE
+                            WHEN node = start THEN 3
+                            WHEN EXISTS((start)--(node)) THEN 2
+                            ELSE 1
+                        END DESC,
+                        degree DESC
                     LIMIT $max_nodes
-                    WITH collect(node) AS filtered_nodes, nodes, relationships
-                    RETURN filtered_nodes AS nodes,
-                           [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
+                    WITH collect({node: node}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {
+                            "max_nodes": MAX_GRAPH_NODES,
+                            "entity_id": node_label,
+                            "inclusive": inclusive,
+                            "max_depth": max_depth,
+                            "min_degree": min_degree,
+                        },
                     )
 
-                record = await result_set.single()
+                try:
+                    record = await result_set.single()
 
-                if record:
-                    # Handle nodes (compatible with multi-label cases)
-                    for node in record["nodes"]:
-                        # Use node ID + label combination as unique identifier
-                        node_id = node.id
-                        if node_id not in seen_nodes:
-                            result.nodes.append(
-                                KnowledgeGraphNode(
-                                    id=f"{node_id}",
-                                    labels=list(node.labels),
-                                    properties=dict(node),
+                    if record:
+                        # Handle nodes (compatible with multi-label cases)
+                        for node_info in record["node_info"]:
+                            node = node_info["node"]
+                            node_id = node.id
+                            if node_id not in seen_nodes:
+                                result.nodes.append(
+                                    KnowledgeGraphNode(
+                                        id=f"{node_id}",
+                                        labels=[
+                                            label
+                                            for label in node.labels
+                                            if label != "base"
+                                        ],
+                                        properties=dict(node),
+                                    )
                                 )
-                            )
-                            seen_nodes.add(node_id)
+                                seen_nodes.add(node_id)
 
-                    # Handle relationships (including direction information)
-                    for rel in record["relationships"]:
-                        edge_id = rel.id
-                        if edge_id not in seen_edges:
-                            start = rel.start_node
-                            end = rel.end_node
-                            result.edges.append(
-                                KnowledgeGraphEdge(
-                                    id=f"{edge_id}",
-                                    type=rel.type,
-                                    source=f"{start.id}",
-                                    target=f"{end.id}",
-                                    properties=dict(rel),
+                        # Handle relationships (including direction information)
+                        for rel in record["relationships"]:
+                            edge_id = rel.id
+                            if edge_id not in seen_edges:
+                                start = rel.start_node
+                                end = rel.end_node
+                                result.edges.append(
+                                    KnowledgeGraphEdge(
+                                        id=f"{edge_id}",
+                                        type=rel.type,
+                                        source=f"{start.id}",
+                                        target=f"{end.id}",
+                                        properties=dict(rel),
+                                    )
                                 )
-                            )
-                            seen_edges.add(edge_id)
+                                seen_edges.add(edge_id)
 
-                logger.info(
-                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
-                )
+                        logger.info(
+                            f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
+                        )
+                finally:
+                    await result_set.consume()  # Ensure result set is consumed
 
             except neo4jExceptions.ClientError as e:
-                logger.error(f"APOC query failed: {str(e)}")
-                return await self._robust_fallback(label, max_depth)
+                logger.warning(f"APOC plugin error: {str(e)}")
+                if node_label != "*":
+                    logger.warning(
+                        "Neo4j: falling back to basic Cypher recursive search..."
+                    )
+                    if inclusive:
+                        logger.warning(
+                            "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
+                        )
+                    return await self._robust_fallback(
+                        node_label, max_depth, min_degree
+                    )
 
         return result
 
     async def _robust_fallback(
-        self, label: str, max_depth: int
-    ) -> Dict[str, List[Dict]]:
-        """Enhanced fallback query solution"""
-        result = {"nodes": [], "edges": []}
+        self, node_label: str, max_depth: int, min_degree: int = 0
+    ) -> KnowledgeGraph:
+        """
+        Fallback implementation when APOC plugin is not available or incompatible.
+        This method implements the same functionality as get_knowledge_graph but uses
+        only basic Cypher queries and recursive traversal instead of APOC procedures.
+        """
+        result = KnowledgeGraph()
         visited_nodes = set()
         visited_edges = set()
 
-        async def traverse(current_label: str, current_depth: int):
+        async def traverse(
+            node: KnowledgeGraphNode,
+            edge: Optional[KnowledgeGraphEdge],
+            current_depth: int,
+        ):
+            # Check traversal limits
             if current_depth > max_depth:
+                logger.debug(f"Reached max depth: {max_depth}")
+                return
+            if len(visited_nodes) >= MAX_GRAPH_NODES:
+                logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
                 return
 
-            # Get current node details
-            node = await self.get_node(current_label)
-            if not node:
+            # Check if node already visited
+            if node.id in visited_nodes:
                 return
 
-            node_id = f"{current_label}"
-            if node_id in visited_nodes:
-                return
-            visited_nodes.add(node_id)
+            # Get all edges and target nodes
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                query = """
+                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
+                WITH r, b, id(r) as edge_id, id(b) as target_id
+                RETURN r, b, edge_id, target_id
+                """
+                results = await session.run(query, entity_id=node.id)
 
-            # Add node data (with complete labels)
-            node_data = {k: v for k, v in node.items()}
-            node_data["labels"] = [
-                current_label
-            ]  # Assume get_node method returns label information
-            result["nodes"].append(node_data)
+                # Get all records and release database connection
+                records = await results.fetch(
+                    1000
+                )  # Max neighbour nodes we can handle
+                await results.consume()  # Ensure results are consumed
 
-            # Get all outgoing and incoming edges
-            query = f"""
-            MATCH (a)-[r]-(b)
-            WHERE a:`{current_label}` OR b:`{current_label}`
-            RETURN a, r, b,
-                   CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction
-            """
-            async with self._driver.session(database=self._DATABASE) as session:
-                results = await session.run(query)
-                async for record in results:
-                    # Handle edges
+                # Nodes not connected to start node need to check degree
+                if current_depth > 1 and len(records) < min_degree:
+                    return
+
+                # Add current node to result
+                result.nodes.append(node)
+                visited_nodes.add(node.id)
+
+                # Add edge to result if it exists and not already added
+                if edge and edge.id not in visited_edges:
+                    result.edges.append(edge)
+                    visited_edges.add(edge.id)
+
+                # Prepare nodes and edges for recursive processing
+                nodes_to_process = []
+                for record in records:
                     rel = record["r"]
-                    edge_id = f"{rel.id}_{rel.type}"
+                    edge_id = str(record["edge_id"])
                     if edge_id not in visited_edges:
-                        edge_data = dict(rel)
-                        edge_data.update(
-                            {
-                                "source": list(record["a"].labels)[0],
-                                "target": list(record["b"].labels)[0],
-                                "type": rel.type,
-                                "direction": record["direction"],
-                            }
-                        )
-                        result["edges"].append(edge_data)
-                        visited_edges.add(edge_id)
+                        b_node = record["b"]
+                        target_id = b_node.get("entity_id")
 
-                    # Recursively traverse adjacent nodes
-                    next_label = (
-                        list(record["b"].labels)[0]
-                        if record["direction"] == "OUTGOING"
-                        else list(record["a"].labels)[0]
-                    )
-                    await traverse(next_label, current_depth + 1)
+                        if target_id:  # Only process if target node has entity_id
+                            # Create KnowledgeGraphNode for target
+                            target_node = KnowledgeGraphNode(
+                                id=f"{target_id}",
+                                labels=[
+                                    label for label in b_node.labels if label != "base"
+                                ],
+                                properties=dict(b_node.properties),
+                            )
+
+                            # Create KnowledgeGraphEdge
+                            target_edge = KnowledgeGraphEdge(
+                                id=f"{edge_id}",
+                                type=rel.type,
+                                source=f"{node.id}",
+                                target=f"{target_id}",
+                                properties=dict(rel),
+                            )
+
+                            nodes_to_process.append((target_node, target_edge))
+                        else:
+                            logger.warning(
+                                f"Skipping edge {edge_id} due to missing entity_id on target node"
+                            )
+
+                # Process nodes after releasing database connection
+                for target_node, target_edge in nodes_to_process:
+                    await traverse(target_node, target_edge, current_depth + 1)
+
+        # Get the starting node's data
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            MATCH (n:base {entity_id: $entity_id})
+            RETURN id(n) as node_id, n
+            """
+            node_result = await session.run(query, entity_id=node_label)
+            try:
+                node_record = await node_result.single()
+                if not node_record:
+                    return result
+
+                # Create initial KnowledgeGraphNode
+                start_node = KnowledgeGraphNode(
+                    id=f"{node_record['n'].get('entity_id')}",
+                    labels=[
+                        label for label in node_record["n"].labels if label != "base"
+                    ],
+                    properties=dict(node_record["n"].properties),
+                )
+            finally:
+                await node_result.consume()  # Ensure results are consumed
+
+        # Start traversal with the initial node
+        await traverse(start_node, None, 0)
 
-        await traverse(label, 0)
         return result
 
     async def get_all_labels(self) -> list[str]:
@@ -671,23 +926,28 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             ["Person", "Company", ...]  # Alphabetically sorted label list
         """
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
            # Method 1: Direct metadata query (Available for Neo4j 4.3+)
            # query = "CALL db.labels() YIELD label RETURN label"
 
            # Method 2: Query compatible with older versions
             query = """
-            MATCH (n)
-            WITH DISTINCT labels(n) AS node_labels
-            UNWIND node_labels AS label
-            RETURN DISTINCT label
-            ORDER BY label
+                MATCH (n)
+                WHERE n.entity_id IS NOT NULL
+                RETURN DISTINCT n.entity_id AS label
+                ORDER BY label
             """
-
             result = await session.run(query)
             labels = []
-            async for record in result:
-                labels.append(record["label"])
+            try:
+                async for record in result:
+                    labels.append(record["label"])
+            finally:
+                await (
+                    result.consume()
+                )  # Ensure results are consumed even if processing fails
             return labels
 
     @retry(
@@ -708,15 +968,15 @@ class Neo4JStorage(BaseGraphStorage):
         Args:
             node_id: The label of the node to delete
         """
-        label = await self._ensure_label(node_id)
 
         async def _do_delete(tx: AsyncManagedTransaction):
-            query = f"""
-            MATCH (n:`{label}`)
+            query = """
+            MATCH (n:base {entity_id: $entity_id})
             DETACH DELETE n
             """
-            await tx.run(query)
-            logger.debug(f"Deleted node with label '{label}'")
+            result = await tx.run(query, entity_id=node_id)
+            logger.debug(f"Deleted node with label '{node_id}'")
+            await result.consume()  # Ensure result is fully consumed
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
                 await session.execute_write(_do_delete)
@@ -765,16 +1025,17 @@ class Neo4JStorage(BaseGraphStorage):
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
         for source, target in edges:
-            source_label = await self._ensure_label(source)
-            target_label = await self._ensure_label(target)
 
             async def _do_delete_edge(tx: AsyncManagedTransaction):
-                query = f"""
-                MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
+                query = """
+                MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
                 DELETE r
                 """
-                await tx.run(query)
-                logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
+                result = await tx.run(
+                    query, source_entity_id=source, target_entity_id=target
+                )
+                logger.debug(f"Deleted edge from '{source}' to '{target}'")
+                await result.consume()  # Ensure result is fully consumed
 
             try:
                 async with self._driver.session(database=self._DATABASE) as session:
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 04552e34..c42f0f76 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -531,6 +531,80 @@ class OracleVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error searching records with prefix '{prefix}': {e}")
             return []
 
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Determine the table name based on namespace
+            table_name = namespace_to_table_name(self.namespace)
+            if not table_name:
+                logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
+                return None
+
+            # Create the appropriate ID field name based on namespace
+            id_field = "entity_id" if "NODES" in table_name else "relation_id"
+            if "CHUNKS" in table_name:
+                id_field = "chunk_id"
+
+            # Prepare and execute the query
+            query = f"""
+                SELECT * FROM {table_name}
+                WHERE {id_field} = :id AND workspace = :workspace
+            """
+            params = {"id": id, "workspace": self.db.workspace}
+
+            result = await self.db.query(query, params)
+            return result
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Determine the table name based on namespace
+            table_name = namespace_to_table_name(self.namespace)
+            if not table_name:
+                logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
+                return []
+
+            # Create the appropriate ID field name based on namespace
+            id_field = "entity_id" if "NODES" in table_name else "relation_id"
+            if "CHUNKS" in table_name:
+                id_field = "chunk_id"
+
+            # Format the list of IDs for SQL IN clause
+            ids_list = ", ".join([f"'{id}'" for id in ids])
+
+            # Prepare and execute the query
+            query = f"""
+                SELECT * FROM {table_name}
+                WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
+            """
+            params = {"workspace": self.db.workspace}
+
+            results = await self.db.query(query, params, multirows=True)
+            return results or []
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []
+
 
 @final
 @dataclass
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 1d525bdb..49d462f6 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -621,6 +621,60 @@ class PGVectorStorage(BaseVectorStorage):
             logger.error(f"Error during prefix search for '{prefix}': {e}")
             return []
 
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        table_name = namespace_to_table_name(self.namespace)
+        if not table_name:
+            logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
+            return None
+
+        query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id=$2"
+        params = {"workspace": self.db.workspace, "id": id}
+
+        try:
+            result = await self.db.query(query, params)
+            if result:
+                return dict(result)
+            return None
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID 
{id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for IDs lookup: {self.namespace}") + return [] + + ids_str = ",".join([f"'{id}'" for id in ids]) + query = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" + params = {"workspace": self.db.workspace} + + try: + results = await self.db.query(query, params, multirows=True) + return [dict(record) for record in results] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + @final @dataclass diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index c8c154aa..736887a6 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -7,12 +7,18 @@ from typing import Any, Dict, Optional, Union, TypeVar, Generic # Define a direct print function for critical logs that must be visible in all processes -def direct_log(message, level="INFO"): +def direct_log(message, level="INFO", enable_output: bool = True): """ Log a message directly to stderr to ensure visibility in all processes, including the Gunicorn master process. + + Args: + message: The message to log + level: Log level (default: "INFO") + enable_output: Whether to actually output the log (default: True) """ - print(f"{level}: {message}", file=sys.stderr, flush=True) + if enable_output: + print(f"{level}: {message}", file=sys.stderr, flush=True) T = TypeVar("T") @@ -32,55 +38,165 @@ _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated _storage_lock: Optional[LockType] = None _internal_lock: Optional[LockType] = None _pipeline_status_lock: Optional[LockType] = None +_graph_db_lock: Optional[LockType] = None +_data_init_lock: Optional[LockType] = None class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" - def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool): + def __init__( + self, + lock: Union[ProcessLock, asyncio.Lock], + is_async: bool, + name: str = "unnamed", + enable_logging: bool = True, + ): self._lock = lock self._is_async = is_async + self._pid = os.getpid() # for debug only + self._name = name # for debug only + self._enable_logging = enable_logging # for debug only async def __aenter__(self) -> "UnifiedLock[T]": - if self._is_async: - await self._lock.acquire() - else: - self._lock.acquire() - return self + try: + direct_log( + f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", + enable_output=self._enable_logging, + ) + if self._is_async: + await self._lock.acquire() + else: + self._lock.acquire() + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", + enable_output=self._enable_logging, + ) + return self + except Exception as e: + direct_log( + f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", + level="ERROR", + enable_output=self._enable_logging, + ) + raise async def __aexit__(self, exc_type, exc_val, exc_tb): - if self._is_async: - self._lock.release() - else: - self._lock.release() + try: + direct_log( + f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", + 
enable_output=self._enable_logging, + ) + if self._is_async: + self._lock.release() + else: + self._lock.release() + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", + enable_output=self._enable_logging, + ) + except Exception as e: + direct_log( + f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", + level="ERROR", + enable_output=self._enable_logging, + ) + raise def __enter__(self) -> "UnifiedLock[T]": """For backward compatibility""" - if self._is_async: - raise RuntimeError("Use 'async with' for shared_storage lock") - self._lock.acquire() - return self + try: + if self._is_async: + raise RuntimeError("Use 'async with' for shared_storage lock") + direct_log( + f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", + enable_output=self._enable_logging, + ) + self._lock.acquire() + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", + enable_output=self._enable_logging, + ) + return self + except Exception as e: + direct_log( + f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", + level="ERROR", + enable_output=self._enable_logging, + ) + raise def __exit__(self, exc_type, exc_val, exc_tb): """For backward compatibility""" - if self._is_async: - raise RuntimeError("Use 'async with' for shared_storage lock") - self._lock.release() + try: + if self._is_async: + raise RuntimeError("Use 'async with' for shared_storage lock") + direct_log( + f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", + enable_output=self._enable_logging, + ) + self._lock.release() + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", + enable_output=self._enable_logging, + ) + except Exception as e: + direct_log( + f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", + level="ERROR", + enable_output=self._enable_logging, + ) + raise -def get_internal_lock() -> UnifiedLock: +def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess) + return UnifiedLock( + lock=_internal_lock, + is_async=not is_multiprocess, + name="internal_lock", + enable_logging=enable_logging, + ) -def get_storage_lock() -> UnifiedLock: +def get_storage_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess) + return UnifiedLock( + lock=_storage_lock, + is_async=not is_multiprocess, + name="storage_lock", + enable_logging=enable_logging, + ) -def get_pipeline_status_lock() -> UnifiedLock: +def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess) + return UnifiedLock( + lock=_pipeline_status_lock, + is_async=not is_multiprocess, + name="pipeline_status_lock", + enable_logging=enable_logging, + ) + + +def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: + """return unified graph database lock for ensuring atomic operations""" + return UnifiedLock( + lock=_graph_db_lock, + is_async=not is_multiprocess, + name="graph_db_lock", + enable_logging=enable_logging, + ) + + +def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: + """return unified data initialization lock for 
ensuring atomic data initialization"""
+    return UnifiedLock(
+        lock=_data_init_lock,
+        is_async=not is_multiprocess,
+        name="data_init_lock",
+        enable_logging=enable_logging,
+    )


 def initialize_share_data(workers: int = 1):
@@ -108,6 +224,8 @@ def initialize_share_data(workers: int = 1):
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
+        _graph_db_lock, \
+        _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
@@ -120,14 +238,16 @@ def initialize_share_data(workers: int = 1):
         )
         return

-    _manager = Manager()
     _workers = workers

     if workers > 1:
         is_multiprocess = True
+        _manager = Manager()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()
+        _graph_db_lock = _manager.Lock()
+        _data_init_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()
@@ -139,6 +259,8 @@ def initialize_share_data(workers: int = 1):
         _internal_lock = asyncio.Lock()
         _storage_lock = asyncio.Lock()
         _pipeline_status_lock = asyncio.Lock()
+        _graph_db_lock = asyncio.Lock()
+        _data_init_lock = asyncio.Lock()
         _shared_dicts = {}
         _init_flags = {}
         _update_flags = {}
@@ -164,6 +286,7 @@ async def initialize_pipeline_status():
     history_messages = _manager.list() if is_multiprocess else []
     pipeline_namespace.update(
         {
+            "autoscanned": False,  # Auto-scan started
             "busy": False,  # Control concurrent processes
             "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
             "job_start": None,  # Job start time
@@ -200,7 +323,12 @@ async def get_update_flag(namespace: str):
         if is_multiprocess and _manager is not None:
             new_update_flag = _manager.Value("b", False)
         else:
-            new_update_flag = False
+            # Create a simple mutable object to store a boolean value for compatibility with multiprocess mode
+            class MutableBoolean:
+                def __init__(self, initial_value=False):
+                    self.value = initial_value
+
+            new_update_flag = MutableBoolean(False)

         _update_flags[namespace].append(new_update_flag)
         return new_update_flag
@@ -220,7 +348,26 @@ async def set_all_update_flags(namespace: str):
             if is_multiprocess:
                 _update_flags[namespace][i].value = True
             else:
-                _update_flags[namespace][i] = True
+                # Use .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = True
+
+
+async def clear_all_update_flags(namespace: str):
+    """Clear all update flags of a namespace, indicating all workers have finished reloading data from files"""
+    global _update_flags
+    if _update_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    async with get_internal_lock():
+        if namespace not in _update_flags:
+            raise ValueError(f"Namespace {namespace} not found in update flags")
+        # Update flags for both modes
+        for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
+                _update_flags[namespace][i].value = False
+            else:
+                # Use .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = False


 async def get_all_update_flags_status() -> Dict[str, list]:
@@ -247,7 +394,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     return result


-def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(namespace: str) -> bool:
     """
     Returns True if the current worker (process) gets initialization permission for loading data later.
     A worker that does not get the permission is prohibited from loading data from files.
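The two new locks are acquired exactly like the existing ones. A minimal usage sketch of `get_graph_db_lock` in single-process mode (the body of the `async with` block is a hypothetical placeholder, not code from this change):

```python
import asyncio

from lightrag.kg.shared_storage import get_graph_db_lock, initialize_share_data


async def merge_and_upsert_atomically():
    # enable_logging=True opts this call site into the acquire/release
    # diagnostics emitted via direct_log(); it defaults to False.
    graph_db_lock = get_graph_db_lock(enable_logging=True)
    async with graph_db_lock:
        ...  # hypothetical placeholder: merge + upsert nodes and edges here


initialize_share_data(workers=1)  # workers <= 1 keeps plain asyncio locks
asyncio.run(merge_and_upsert_atomically())
```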
@@ -257,15 +404,17 @@ def try_initialize_namespace(namespace: str) -> bool: if _init_flags is None: raise ValueError("Try to create nanmespace before Shared-Data is initialized") - if namespace not in _init_flags: - _init_flags[namespace] = True + async with get_internal_lock(): + if namespace not in _init_flags: + _init_flags[namespace] = True + direct_log( + f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]" + ) + return True direct_log( - f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]" + f"Process {os.getpid()} storage namespace already initialized: [{namespace}]" ) - return True - direct_log( - f"Process {os.getpid()} storage namespace already initialized: [{namespace}]" - ) + return False @@ -304,6 +453,8 @@ def finalize_share_data(): _storage_lock, \ _internal_lock, \ _pipeline_status_lock, \ + _graph_db_lock, \ + _data_init_lock, \ _shared_dicts, \ _init_flags, \ _initialized, \ @@ -369,6 +520,8 @@ def finalize_share_data(): _storage_lock = None _internal_lock = None _pipeline_status_lock = None + _graph_db_lock = None + _data_init_lock = None _update_flags = None direct_log(f"Process {os.getpid()} storage data finalization complete") diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 9d807798..0982c914 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -465,6 +465,100 @@ class TiDBVectorDBStorage(BaseVectorStorage): logger.error(f"Error searching records with prefix '{prefix}': {e}") return [] + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + try: + # Determine which table to query based on namespace + if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: + sql_template = """ + SELECT entity_id as id, name as entity_name, entity_type, description, content + FROM LIGHTRAG_GRAPH_NODES + WHERE entity_id = :entity_id AND workspace = :workspace + """ + params = {"entity_id": id, "workspace": self.db.workspace} + elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: + sql_template = """ + SELECT relation_id as id, source_name as src_id, target_name as tgt_id, + keywords, description, content + FROM LIGHTRAG_GRAPH_EDGES + WHERE relation_id = :relation_id AND workspace = :workspace + """ + params = {"relation_id": id, "workspace": self.db.workspace} + elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: + sql_template = """ + SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE chunk_id = :chunk_id AND workspace = :workspace + """ + params = {"chunk_id": id, "workspace": self.db.workspace} + else: + logger.warning( + f"Namespace {self.namespace} not supported for get_by_id" + ) + return None + + result = await self.db.query(sql_template, params=params) + return result + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + try: + # Format IDs for SQL IN clause + ids_str = ", ".join([f"'{id}'" for id in ids]) + + # Determine which table to query based on namespace + if self.namespace == NameSpace.VECTOR_STORE_ENTITIES: + sql_template = f""" + SELECT entity_id as id, name as 
entity_name, entity_type, description, content + FROM LIGHTRAG_GRAPH_NODES + WHERE entity_id IN ({ids_str}) AND workspace = :workspace + """ + elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS: + sql_template = f""" + SELECT relation_id as id, source_name as src_id, target_name as tgt_id, + keywords, description, content + FROM LIGHTRAG_GRAPH_EDGES + WHERE relation_id IN ({ids_str}) AND workspace = :workspace + """ + elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS: + sql_template = f""" + SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE chunk_id IN ({ids_str}) AND workspace = :workspace + """ + else: + logger.warning( + f"Namespace {self.namespace} not supported for get_by_ids" + ) + return [] + + params = {"workspace": self.db.workspace} + results = await self.db.query(sql_template, params=params, multirows=True) + return results if results else [] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + @final @dataclass diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7fb24eee..3a5e4e84 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -30,11 +30,10 @@ from .namespace import NameSpace, make_namespace from .operate import ( chunking_by_token_size, extract_entities, - extract_keywords_only, kg_query, - kg_query_with_keywords, mix_kg_vector_query, naive_query, + query_with_keywords, ) from .prompt import GRAPH_FIELD_SEP, PROMPTS from .utils import ( @@ -45,6 +44,9 @@ from .utils import ( encode_string_by_tiktoken, lazy_external_import, limit_async_func_call, + get_content_summary, + clean_text, + check_storage_env_vars, logger, ) from .types import KnowledgeGraph @@ -309,7 +311,7 @@ class LightRAG: # Verify storage implementation compatibility verify_storage_implementation(storage_type, storage_name) # Check environment variables - # self.check_storage_env_vars(storage_name) + check_storage_env_vars(storage_name) # Ensure vector_db_storage_cls_kwargs has required fields self.vector_db_storage_cls_kwargs = { @@ -354,6 +356,9 @@ class LightRAG: namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), + global_config=asdict( + self + ), # Add global_config to ensure cache works properly embedding_func=self.embedding_func, ) @@ -404,18 +409,8 @@ class LightRAG: embedding_func=None, ) - if self.llm_response_cache and hasattr( - self.llm_response_cache, "global_config" - ): - hashing_kv = self.llm_response_cache - else: - hashing_kv = self.key_string_value_json_storage_cls( # type: ignore - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ) + # Directly use llm_response_cache, don't create a new object + hashing_kv = self.llm_response_cache self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -543,11 +538,6 @@ class LightRAG: storage_class = lazy_external_import(import_path, storage_name) return storage_class - @staticmethod - def clean_text(text: str) -> str: - """Clean text by removing null bytes (0x00) and whitespace""" - return text.strip().replace("\x00", "") - def insert( self, input: str | list[str], @@ -590,6 +580,7 @@ class LightRAG: split_by_character, split_by_character_only ) + # TODO: deprecated, use insert instead def insert_custom_chunks( self, full_text: str, @@ -601,14 +592,15 @@ class LightRAG: self.ainsert_custom_chunks(full_text, text_chunks, doc_id) ) + 
# TODO: deprecated, use ainsert instead async def ainsert_custom_chunks( self, full_text: str, text_chunks: list[str], doc_id: str | None = None ) -> None: update_storage = False try: # Clean input texts - full_text = self.clean_text(full_text) - text_chunks = [self.clean_text(chunk) for chunk in text_chunks] + full_text = clean_text(full_text) + text_chunks = [clean_text(chunk) for chunk in text_chunks] # Process cleaned texts if doc_id is None: @@ -687,7 +679,7 @@ class LightRAG: contents = {id_: doc for id_, doc in zip(ids, input)} else: # Clean input text and remove duplicates - input = list(set(self.clean_text(doc) for doc in input)) + input = list(set(clean_text(doc) for doc in input)) # Generate contents dict of MD5 hash IDs and documents contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input} @@ -703,7 +695,7 @@ class LightRAG: new_docs: dict[str, Any] = { id_: { "content": content, - "content_summary": self._get_content_summary(content), + "content_summary": get_content_summary(content), "content_length": len(content), "status": DocStatus.PENDING, "created_at": datetime.now().isoformat(), @@ -892,7 +884,9 @@ class LightRAG: self.chunks_vdb.upsert(chunks) ) entity_relation_task = asyncio.create_task( - self._process_entity_relation_graph(chunks) + self._process_entity_relation_graph( + chunks, pipeline_status, pipeline_status_lock + ) ) full_docs_task = asyncio.create_task( self.full_docs.upsert( @@ -1007,21 +1001,27 @@ class LightRAG: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None: + async def _process_entity_relation_graph( + self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None + ) -> None: try: await extract_entities( chunk, knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, - llm_response_cache=self.llm_response_cache, global_config=asdict(self), + pipeline_status=pipeline_status, + pipeline_status_lock=pipeline_status_lock, + llm_response_cache=self.llm_response_cache, ) except Exception as e: logger.error("Failed to extract entities and relationships") raise e - async def _insert_done(self) -> None: + async def _insert_done( + self, pipeline_status=None, pipeline_status_lock=None + ) -> None: tasks = [ cast(StorageNameSpace, storage_inst).index_done_callback() for storage_inst in [ # type: ignore @@ -1040,12 +1040,10 @@ class LightRAG: log_message = "All Insert done" logger.info(log_message) - # 获取 pipeline_status 并更新 latest_message 和 history_messages - from lightrag.kg.shared_storage import get_namespace_data - - pipeline_status = await get_namespace_data("pipeline_status") - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) def insert_custom_kg( self, custom_kg: dict[str, Any], full_doc_id: str = None @@ -1062,7 +1060,7 @@ class LightRAG: all_chunks_data: dict[str, dict[str, str]] = {} chunk_to_source_map: dict[str, str] = {} for chunk_data in custom_kg.get("chunks", []): - chunk_content = self.clean_text(chunk_data["content"]) + chunk_content = clean_text(chunk_data["content"]) source_id = chunk_data["source_id"] tokens = len( encode_string_by_tiktoken( @@ 
-1260,16 +1258,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) elif param.mode == "naive": @@ -1279,16 +1268,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) elif param.mode == "mix": @@ -1301,16 +1281,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) else: @@ -1322,8 +1293,17 @@ class LightRAG: self, query: str, prompt: str, param: QueryParam = QueryParam() ): """ - 1. Extract keywords from the 'query' using new function in operate.py. - 2. Then run the standard aquery() flow with the final prompt (formatted_question). + Query with separate keyword extraction step. + + This method extracts keywords from the query first, then uses them for the query. + + Args: + query: User query + prompt: Additional prompt for the query + param: Query parameters + + Returns: + Query response """ loop = always_get_an_event_loop() return loop.run_until_complete( @@ -1334,100 +1314,29 @@ class LightRAG: self, query: str, prompt: str, param: QueryParam = QueryParam() ) -> str | AsyncIterator[str]: """ - 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. - 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed. + Async version of query_with_separate_keyword_extraction. 
+ + Args: + query: User query + prompt: Additional prompt for the query + param: Query parameters + + Returns: + Query response or async iterator """ - # --------------------- - # STEP 1: Keyword Extraction - # --------------------- - hl_keywords, ll_keywords = await extract_keywords_only( - text=query, + response = await query_with_keywords( + query=query, + prompt=prompt, param=param, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entities_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + chunks_vdb=self.chunks_vdb, + text_chunks_db=self.text_chunks, global_config=asdict(self), - hashing_kv=self.llm_response_cache - or self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, ) - param.hl_keywords = hl_keywords - param.ll_keywords = ll_keywords - - # --------------------- - # STEP 2: Final Query Logic - # --------------------- - - # Create a new string with the prompt and the keywords - ll_keywords_str = ", ".join(ll_keywords) - hl_keywords_str = ", ".join(hl_keywords) - formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}" - - if param.mode in ["local", "global", "hybrid"]: - response = await kg_query_with_keywords( - formatted_question, - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.text_chunks, - param, - asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), - ) - elif param.mode == "naive": - response = await naive_query( - formatted_question, - self.chunks_vdb, - self.text_chunks, - param, - asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), - ) - elif param.mode == "mix": - response = await mix_kg_vector_query( - formatted_question, - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.chunks_vdb, - self.text_chunks, - param, - asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), - ) - else: - raise ValueError(f"Unknown mode {param.mode}") - await self._query_done() return response @@ -1525,21 +1434,6 @@ class LightRAG: ] ) - def _get_content_summary(self, content: str, max_length: int = 100) -> str: - """Get summary of document content - - Args: - content: Original document content - max_length: Maximum length of summary - - Returns: - Truncated content with ellipsis if needed - """ - content = content.strip() - if len(content) <= max_length: - return content - return content[:max_length] + "..." 
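With the keyword-extraction and dispatch logic moved into `operate.query_with_keywords`, the public API stays the same. A hedged usage sketch (the constructor arguments shown are illustrative; LLM and embedding setup are omitted):

```python
from lightrag import LightRAG, QueryParam

# Illustrative setup only; llm_model_func and embedding_func are omitted here.
rag = LightRAG(working_dir="./rag_storage")


async def ask() -> str:
    await rag.initialize_storages()
    # Keywords are extracted first, injected into the formatted question,
    # then dispatched on param.mode (local/global/hybrid, naive, or mix).
    return await rag.aquery_with_separate_keyword_extraction(
        query="How are the main conflicts resolved?",
        prompt="Answer concisely in three bullet points.",
        param=QueryParam(mode="mix"),
    )
```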
- async def get_processing_status(self) -> dict[str, int]: """Get current document processing status counts @@ -1816,19 +1710,7 @@ class LightRAG: async def get_entity_info( self, entity_name: str, include_vector_data: bool = False ) -> dict[str, str | None | dict[str, str]]: - """Get detailed information of an entity - - Args: - entity_name: Entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - - Returns: - dict: A dictionary containing entity information, including: - - entity_name: Entity name - - source_id: Source document ID - - graph_data: Complete node data from the graph database - - vector_data: (optional) Data from the vector database - """ + """Get detailed information of an entity""" # Get information from the graph node_data = await self.chunk_entity_relation_graph.get_node(entity_name) @@ -1843,29 +1725,15 @@ class LightRAG: # Optional: Get vector database information if include_vector_data: entity_id = compute_mdhash_id(entity_name, prefix="ent-") - vector_data = self.entities_vdb._client.get([entity_id]) - result["vector_data"] = vector_data[0] if vector_data else None + vector_data = await self.entities_vdb.get_by_id(entity_id) + result["vector_data"] = vector_data return result async def get_relation_info( self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ) -> dict[str, str | None | dict[str, str]]: - """Get detailed information of a relationship - - Args: - src_entity: Source entity name (no need for quotes) - tgt_entity: Target entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - - Returns: - dict: A dictionary containing relationship information, including: - - src_entity: Source entity name - - tgt_entity: Target entity name - - source_id: Source document ID - - graph_data: Complete edge data from the graph database - - vector_data: (optional) Data from the vector database - """ + """Get detailed information of a relationship""" # Get information from the graph edge_data = await self.chunk_entity_relation_graph.get_edge( @@ -1883,8 +1751,8 @@ class LightRAG: # Optional: Get vector database information if include_vector_data: rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-") - vector_data = self.relationships_vdb._client.get([rel_id]) - result["vector_data"] = vector_data[0] if vector_data else None + vector_data = await self.relationships_vdb.get_by_id(rel_id) + result["vector_data"] = vector_data return result @@ -2682,6 +2550,12 @@ class LightRAG: # 9. 
Delete source entities for entity_name in source_entities: + if entity_name == target_entity: + logger.info( + f"Skipping deletion of '{entity_name}' as it's also the target entity" + ) + continue + # Delete entity node from knowledge graph await self.chunk_entity_relation_graph.delete_node(entity_name) diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py index 84e45cfb..3405d29e 100644 --- a/lightrag/llm/azure_openai.py +++ b/lightrag/llm/azure_openai.py @@ -55,6 +55,7 @@ async def azure_openai_complete_if_cache( openai_async_client = AsyncAzureOpenAI( azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_deployment=model, api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) @@ -136,6 +137,7 @@ async def azure_openai_embed( openai_async_client = AsyncAzureOpenAI( azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_deployment=model, api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) diff --git a/lightrag/operate.py b/lightrag/operate.py index 5e90a77b..1815f308 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import json import re +import os from typing import Any, AsyncIterator from collections import Counter, defaultdict @@ -140,18 +141,36 @@ async def _handle_single_entity_extraction( ): if len(record_attributes) < 4 or record_attributes[0] != '"entity"': return None - # add this record as a node in the G + + # Clean and validate entity name entity_name = clean_str(record_attributes[1]).strip('"') if not entity_name.strip(): + logger.warning( + f"Entity extraction error: empty entity name in: {record_attributes}" + ) return None + + # Clean and validate entity type entity_type = clean_str(record_attributes[2]).strip('"') + if not entity_type.strip() or entity_type.startswith('("'): + logger.warning( + f"Entity extraction error: invalid entity type in: {record_attributes}" + ) + return None + + # Clean and validate description entity_description = clean_str(record_attributes[3]).strip('"') - entity_source_id = chunk_key + if not entity_description.strip(): + logger.warning( + f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'" + ) + return None + return dict( entity_name=entity_name, entity_type=entity_type, description=entity_description, - source_id=entity_source_id, + source_id=chunk_key, metadata={"created_at": time.time()}, ) @@ -220,6 +239,7 @@ async def _merge_nodes_then_upsert( entity_name, description, global_config ) node_data = dict( + entity_id=entity_name, entity_type=entity_type, description=description, source_id=source_id, @@ -301,6 +321,7 @@ async def _merge_edges_then_upsert( await knowledge_graph_inst.upsert_node( need_insert_id, node_data={ + "entity_id": need_insert_id, "source_id": source_id, "description": description, "entity_type": "UNKNOWN", @@ -337,11 +358,10 @@ async def extract_entities( entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, global_config: dict[str, str], + pipeline_status: dict = None, + pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ) -> None: - from lightrag.kg.shared_storage import get_namespace_data - - pipeline_status = await get_namespace_data("pipeline_status") use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] enable_llm_cache_for_entity_extract: bool = global_config[ @@ 
-400,6 +420,7 @@ async def extract_entities( else: _prompt = input_text + # TODO: add cache_type="extract" arg_hash = compute_args_hash(_prompt) cached_return, _1, _2, _3 = await handle_cache( llm_response_cache, @@ -407,7 +428,6 @@ async def extract_entities( _prompt, "default", cache_type="extract", - force_llm_cache=True, ) if cached_return: logger.debug(f"Found cache for {arg_hash}") @@ -436,47 +456,22 @@ async def extract_entities( else: return await use_llm_func(input_text) - async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): - """ "Prpocess a single chunk + async def _process_extraction_result(result: str, chunk_key: str): + """Process a single extraction result (either initial or gleaning) Args: - chunk_key_dp (tuple[str, TextChunkSchema]): - ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) + result (str): The extraction result to process + chunk_key (str): The chunk key for source tracking + Returns: + tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships """ - nonlocal processed_chunks - chunk_key = chunk_key_dp[0] - chunk_dp = chunk_key_dp[1] - content = chunk_dp["content"] - # hint_prompt = entity_extract_prompt.format(**context_base, input_text=content) - hint_prompt = entity_extract_prompt.format( - **context_base, input_text="{input_text}" - ).format(**context_base, input_text=content) - - final_result = await _user_llm_func_with_cache(hint_prompt) - history = pack_user_ass_to_openai_messages(hint_prompt, final_result) - for now_glean_index in range(entity_extract_max_gleaning): - glean_result = await _user_llm_func_with_cache( - continue_prompt, history_messages=history - ) - - history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) - final_result += glean_result - if now_glean_index == entity_extract_max_gleaning - 1: - break - - if_loop_result: str = await _user_llm_func_with_cache( - if_loop_prompt, history_messages=history - ) - if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() - if if_loop_result != "yes": - break + maybe_nodes = defaultdict(list) + maybe_edges = defaultdict(list) records = split_string_by_multi_markers( - final_result, + result, [context_base["record_delimiter"], context_base["completion_delimiter"]], ) - maybe_nodes = defaultdict(list) - maybe_edges = defaultdict(list) for record in records: record = re.search(r"\((.*)\)", record) if record is None: @@ -485,6 +480,7 @@ async def extract_entities( record_attributes = split_string_by_multi_markers( record, [context_base["tuple_delimiter"]] ) + if_entities = await _handle_single_entity_extraction( record_attributes, chunk_key ) @@ -499,13 +495,71 @@ async def extract_entities( maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append( if_relation ) + + return maybe_nodes, maybe_edges + + async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + """Process a single chunk + Args: + chunk_key_dp (tuple[str, TextChunkSchema]): + ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) + """ + nonlocal processed_chunks + chunk_key = chunk_key_dp[0] + chunk_dp = chunk_key_dp[1] + content = chunk_dp["content"] + + # Get initial extraction + hint_prompt = entity_extract_prompt.format( + **context_base, input_text="{input_text}" + ).format(**context_base, input_text=content) + + final_result = await _user_llm_func_with_cache(hint_prompt) + history = pack_user_ass_to_openai_messages(hint_prompt, 
final_result) + + # Process initial extraction + maybe_nodes, maybe_edges = await _process_extraction_result( + final_result, chunk_key + ) + + # Process additional gleaning results + for now_glean_index in range(entity_extract_max_gleaning): + glean_result = await _user_llm_func_with_cache( + continue_prompt, history_messages=history + ) + + history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) + + # Process gleaning result separately + glean_nodes, glean_edges = await _process_extraction_result( + glean_result, chunk_key + ) + + # Merge results + for entity_name, entities in glean_nodes.items(): + maybe_nodes[entity_name].extend(entities) + for edge_key, edges in glean_edges.items(): + maybe_edges[edge_key].extend(edges) + + if now_glean_index == entity_extract_max_gleaning - 1: + break + + if_loop_result: str = await _user_llm_func_with_cache( + if_loop_prompt, history_messages=history + ) + if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() + if if_loop_result != "yes": + break + processed_chunks += 1 entities_count = len(maybe_nodes) relations_count = len(maybe_edges) log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)" logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) return dict(maybe_nodes), dict(maybe_edges) tasks = [_process_single_content(c) for c in ordered_chunks] @@ -519,42 +573,58 @@ async def extract_entities( for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) - all_entities_data = await asyncio.gather( - *[ - _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) - for k, v in maybe_nodes.items() - ] - ) + from .kg.shared_storage import get_graph_db_lock - all_relationships_data = await asyncio.gather( - *[ - _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config) - for k, v in maybe_edges.items() - ] - ) + graph_db_lock = get_graph_db_lock(enable_logging=False) + + # Ensure that nodes and edges are merged and upserted atomically + async with graph_db_lock: + all_entities_data = await asyncio.gather( + *[ + _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) + for k, v in maybe_nodes.items() + ] + ) + + all_relationships_data = await asyncio.gather( + *[ + _merge_edges_then_upsert( + k[0], k[1], v, knowledge_graph_inst, global_config + ) + for k, v in maybe_edges.items() + ] + ) if not (all_entities_data or all_relationships_data): log_message = "Didn't extract any entities and relationships." 
        logger.info(log_message)
-        pipeline_status["latest_message"] = log_message
-        pipeline_status["history_messages"].append(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
         return

     if not all_entities_data:
         log_message = "Didn't extract any entities"
         logger.info(log_message)
-        pipeline_status["latest_message"] = log_message
-        pipeline_status["history_messages"].append(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
     if not all_relationships_data:
         log_message = "Didn't extract any relationships"
         logger.info(log_message)
-        pipeline_status["latest_message"] = log_message
-        pipeline_status["history_messages"].append(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)

     log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
     logger.info(log_message)
-    pipeline_status["latest_message"] = log_message
-    pipeline_status["history_messages"].append(log_message)
+    if pipeline_status is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
     verbose_debug(
         f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
     )
@@ -1020,6 +1090,7 @@ async def _build_query_context(
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
+    logger.info(f"Process {os.getpid()} building query context...")
     if query_param.mode == "local":
         entities_context, relations_context, text_units_context = await _get_node_data(
             ll_keywords,
@@ -1845,3 +1916,90 @@ async def kg_query_with_keywords(
     )

     return response
+
+
+async def query_with_keywords(
+    query: str,
+    prompt: str,
+    param: QueryParam,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    chunks_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
+    """
+    Extract keywords from the query and then use them for retrieving information.
+
+    1. Extracts high-level and low-level keywords from the query
+    2. Formats the query with the extracted keywords and prompt
+    3. Uses the appropriate query method based on param.mode
+
+    Args:
+        query: The user's query
+        prompt: Additional prompt to prepend to the query
+        param: Query parameters
+        knowledge_graph_inst: Knowledge graph storage
+        entities_vdb: Entities vector database
+        relationships_vdb: Relationships vector database
+        chunks_vdb: Document chunks vector database
+        text_chunks_db: Text chunks storage
+        global_config: Global configuration
+        hashing_kv: Cache storage
+
+    Returns:
+        Query response or async iterator
+    """
+    # Extract keywords
+    hl_keywords, ll_keywords = await extract_keywords_only(
+        text=query,
+        param=param,
+        global_config=global_config,
+        hashing_kv=hashing_kv,
+    )
+
+    param.hl_keywords = hl_keywords
+    param.ll_keywords = ll_keywords
+
+    # Create a new string with the prompt and the keywords
+    ll_keywords_str = ", ".join(ll_keywords)
+    hl_keywords_str = ", ".join(hl_keywords)
+    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+
+    # Use appropriate query method based on mode
+    if param.mode in ["local", "global", "hybrid"]:
+        return await kg_query_with_keywords(
+            formatted_question,
+            knowledge_graph_inst,
+            entities_vdb,
+            relationships_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    elif param.mode == "naive":
+        return await naive_query(
+            formatted_question,
+            chunks_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    elif param.mode == "mix":
+        return await mix_kg_vector_query(
+            formatted_question,
+            knowledge_graph_inst,
+            entities_vdb,
+            relationships_vdb,
+            chunks_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    else:
+        raise ValueError(f"Unknown mode {param.mode}")
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 1486ccf8..f81cd441 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -236,7 +236,7 @@ Given the query and conversation history, list both high-level and low-level key
---Instructions---

- Consider both the current query and relevant conversation history when extracting keywords
-- Output the keywords in JSON format
+- Output the keywords in JSON format; it will be parsed by a JSON parser, so do not add any extra content to the output
- The JSON should have two keys:
  - "high_level_keywords" for overarching concepts or themes
  - "low_level_keywords" for specific entities or details
diff --git a/lightrag/utils.py b/lightrag/utils.py
index bb1d6fae..b8f00c5d 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -633,15 +633,15 @@ async def handle_cache(
     prompt,
     mode="default",
     cache_type=None,
-    force_llm_cache=False,
 ):
     """Generic cache handling function"""
-    if hashing_kv is None or not (
-        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
-    ):
+    if hashing_kv is None:
         return None, None, None, None

-    if mode != "default":
+    if mode != "default":  # handle cache for all types of queries
+        if not hashing_kv.global_config.get("enable_llm_cache"):
+            return None, None, None, None
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)

         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding similarity to match cache
             current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -667,24 +666,29 @@
                 cache_type=cache_type,
             )
             if best_cached_response is not None:
-                logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
+                logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})")
                 return best_cached_response, None, None, None
             else:
                 # if caching keyword embedding is enabled, return the quantized embedding for saving it later
-                logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
+                logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
                 return None, quantized, min_val, max_val

-    # For default mode or is_embedding_cache_enabled is False, use regular cache
-    # default mode is for extract_entities or naive query
+    else:  # handle cache for entity extraction
+        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
+            return None, None, None, None
+
+    # Here are the conditions for code reaching this point:
+    # 1. All query modes: enable_llm_cache is True and embedding similarity is not enabled
+    # 2. Entity extract: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
         mode_cache = await hashing_kv.get_by_id(mode) or {}
     if args_hash in mode_cache:
-        logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
+        logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
         return mode_cache[args_hash]["return"], None, None, None

-    logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
+    logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
     return None, None, None, None
@@ -701,9 +705,22 @@ class CacheData:

 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
+    """Save data to cache, with improved handling for streaming responses and duplicate content.
+
+    Args:
+        hashing_kv: The key-value storage for caching
+        cache_data: The cache data to save
+    """
+    # Skip if storage is None or content is empty
+    if hashing_kv is None or not cache_data.content:
         return

+    # If content is a streaming response, don't cache it
+    if hasattr(cache_data.content, "__aiter__"):
+        logger.debug("Streaming response detected, skipping cache")
+        return
+
+    # Get existing cache data
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = (
             await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
@@ -712,6 +729,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     else:
         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

+    # Check if we already have identical content cached
+    if cache_data.args_hash in mode_cache:
+        existing_content = mode_cache[cache_data.args_hash].get("return")
+        if existing_content == cache_data.content:
+            logger.info(
+                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
+            )
+            return
+
+    # Update cache with new content
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
         "cache_type": cache_data.cache_type,
@@ -726,6 +753,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "original_prompt": cache_data.prompt,
     }

+    # Only upsert if there's actual new content
     await hashing_kv.upsert({cache_data.mode: mode_cache})
@@ -862,3 +890,52 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
         return cls(*args, **kwargs)

     return import_class
+
+
+def get_content_summary(content: str, max_length: int = 100) -> str:
+    """Get summary of document content
+
+    Args:
+        content: Original document content
+        max_length: Maximum length of summary
+
+    Returns:
+        Truncated content with ellipsis if needed
+    """
+    content = content.strip()
+    if len(content) <= max_length:
+        return content
+    return content[:max_length] + "..."
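A quick sanity check of the relocated helper's truncation behaviour (the sample strings are arbitrary):

```python
from lightrag.utils import get_content_summary

assert get_content_summary("short note") == "short note"

summary = get_content_summary("x" * 250, max_length=100)
assert summary == "x" * 100 + "..."  # truncated with an ellipsis appended
```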
+
+
+def clean_text(text: str) -> str:
+    """Clean text by removing null bytes (0x00) and stripping surrounding whitespace.
+
+    Args:
+        text: Input text to clean
+
+    Returns:
+        Cleaned text
+    """
+    return text.strip().replace("\x00", "")
+
+
+def check_storage_env_vars(storage_name: str) -> None:
+    """Check that all environment variables required by the storage implementation exist.
+
+    Args:
+        storage_name: Storage implementation name
+
+    Raises:
+        ValueError: If required environment variables are missing
+    """
+    from lightrag.kg import STORAGE_ENV_REQUIREMENTS
+
+    required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+    missing_vars = [var for var in required_vars if var not in os.environ]
+
+    if missing_vars:
+        raise ValueError(
+            f"Storage implementation '{storage_name}' requires the following "
+            f"environment variables: {', '.join(missing_vars)}"
+        )
diff --git a/lightrag_webui/bun.lock b/lightrag_webui/bun.lock
index 6157e38c..3ca0d887 100644
--- a/lightrag_webui/bun.lock
+++ b/lightrag_webui/bun.lock
@@ -34,11 +34,13 @@
       "cmdk": "^1.0.4",
       "graphology": "^0.26.0",
       "graphology-generators": "^0.11.2",
+      "i18next": "^24.2.2",
       "lucide-react": "^0.475.0",
       "minisearch": "^7.1.2",
       "react": "^19.0.0",
       "react-dom": "^19.0.0",
       "react-dropzone": "^14.3.6",
+      "react-i18next": "^15.4.1",
       "react-markdown": "^9.1.0",
       "react-number-format": "^5.4.3",
       "react-syntax-highlighter": "^15.6.1",
@@ -765,8 +767,12 @@
     "hoist-non-react-statics": ["hoist-non-react-statics@3.3.2", "", { "dependencies": { "react-is": "^16.7.0" } }, "sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw=="],
+    "html-parse-stringify": ["html-parse-stringify@3.0.1", "", { "dependencies": { "void-elements": "3.1.0" } }, "sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg=="],
+
     "html-url-attributes": ["html-url-attributes@3.0.1", "", {}, "sha512-ol6UPyBWqsrO6EJySPz2O7ZSr856WDrEzM5zMqp+FJJLGMW35cLYmmZnl0vztAZxRUoNZJFTCohfjuIJ8I4QBQ=="],
+    "i18next": ["i18next@24.2.2", "", { "dependencies": { "@babel/runtime": "^7.23.2" }, "peerDependencies": { "typescript": "^5" }, "optionalPeers": ["typescript"] }, "sha512-NE6i86lBCKRYZa5TaUDkU5S4HFgLIEJRLr3Whf2psgaxBleQ2LC1YW1Vc+SCgkAW7VEzndT6al6+CzegSUHcTQ=="],
+
     "ignore": ["ignore@5.3.2", "", {}, "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g=="],
     "import-fresh": ["import-fresh@3.3.1", "", { "dependencies": { "parent-module": "^1.0.0", "resolve-from": "^4.0.0" } }, "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ=="],
@@ -1093,6 +1099,8 @@
     "react-dropzone": ["react-dropzone@14.3.6", "", { "dependencies": { "attr-accept": "^2.2.4", "file-selector": "^2.1.0", "prop-types": "^15.8.1" }, "peerDependencies": { "react": ">= 16.8 || 18.0.0" } }, "sha512-U792j+x0rcwH/U/Slv/OBNU/LGFYbDLHKKiJoPhNaOianayZevCt4Y5S0CraPssH/6/wT6xhKDfzdXUgCBS0HQ=="],
+    "react-i18next": ["react-i18next@15.4.1", "", { "dependencies": { "@babel/runtime": "^7.25.0", "html-parse-stringify": "^3.0.1" }, "peerDependencies": { "i18next": ">= 23.2.3", "react": ">= 16.8.0" } }, "sha512-ahGab+IaSgZmNPYXdV1n+OYky95TGpFwnKRflX/16dY04DsYYKHtVLjeny7sBSCREEcoMbAgSkFiGLF5g5Oofw=="],
+
     "react-is": ["react-is@16.13.1", "", {}, "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ=="],
     "react-markdown": ["react-markdown@9.1.0", "", { "dependencies": { "@types/hast": "^3.0.0", "@types/mdast": "^4.0.0", "devlop": "^1.0.0", "hast-util-to-jsx-runtime":
"^2.0.0", "html-url-attributes": "^3.0.0", "mdast-util-to-hast": "^13.0.0", "remark-parse": "^11.0.0", "remark-rehype": "^11.0.0", "unified": "^11.0.0", "unist-util-visit": "^5.0.0", "vfile": "^6.0.0" }, "peerDependencies": { "@types/react": ">=18", "react": ">=18" } }, "sha512-xaijuJB0kzGiUdG7nc2MOMDUDBWPyGAjZtUrow9XxUeua8IqeP+VlIfAZ3bphpcLTnSZXz6z9jcVC/TCwbfgdw=="], @@ -1271,6 +1279,8 @@ "vite": ["vite@6.1.1", "", { "dependencies": { "esbuild": "^0.24.2", "postcss": "^8.5.2", "rollup": "^4.30.1" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^18.0.0 || ^20.0.0 || >=22.0.0", "jiti": ">=1.21.0", "less": "*", "lightningcss": "^1.21.0", "sass": "*", "sass-embedded": "*", "stylus": "*", "sugarss": "*", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "jiti", "less", "lightningcss", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-4GgM54XrwRfrOp297aIYspIti66k56v16ZnqHvrIM7mG+HjDlAwS7p+Srr7J6fGvEdOJ5JcQ/D9T7HhtdXDTzA=="], + "void-elements": ["void-elements@3.1.0", "", {}, "sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w=="], + "which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="], "which-boxed-primitive": ["which-boxed-primitive@1.1.1", "", { "dependencies": { "is-bigint": "^1.1.0", "is-boolean-object": "^1.2.1", "is-number-object": "^1.1.1", "is-string": "^1.1.1", "is-symbol": "^1.1.1" } }, "sha512-TbX3mj8n0odCBFVlY8AxkqcHASw3L60jIuF8jFP78az3C2YhmGvqbHBpAjTRH2/xqYunrJ9g1jSyjCjpoWzIAA=="], diff --git a/lightrag_webui/package.json b/lightrag_webui/package.json index 578ee36f..97fba74d 100644 --- a/lightrag_webui/package.json +++ b/lightrag_webui/package.json @@ -43,11 +43,13 @@ "cmdk": "^1.0.4", "graphology": "^0.26.0", "graphology-generators": "^0.11.2", + "i18next": "^24.2.2", "lucide-react": "^0.475.0", "minisearch": "^7.1.2", "react": "^19.0.0", "react-dom": "^19.0.0", "react-dropzone": "^14.3.6", + "react-i18next": "^15.4.1", "react-markdown": "^9.1.0", "react-number-format": "^5.4.3", "react-syntax-highlighter": "^15.6.1", diff --git a/lightrag_webui/src/components/ThemeToggle.tsx b/lightrag_webui/src/components/ThemeToggle.tsx index 8e92d862..ff333ff0 100644 --- a/lightrag_webui/src/components/ThemeToggle.tsx +++ b/lightrag_webui/src/components/ThemeToggle.tsx @@ -3,6 +3,7 @@ import useTheme from '@/hooks/useTheme' import { MoonIcon, SunIcon } from 'lucide-react' import { useCallback } from 'react' import { controlButtonVariant } from '@/lib/constants' +import { useTranslation } from 'react-i18next' /** * Component that toggles the theme between light and dark. @@ -11,13 +12,14 @@ export default function ThemeToggle() { const { theme, setTheme } = useTheme() const setLight = useCallback(() => setTheme('light'), [setTheme]) const setDark = useCallback(() => setTheme('dark'), [setTheme]) + const { t } = useTranslation() if (theme === 'dark') { return ( e.preventDefault()}> - Clear documents - Do you really want to clear all documents? 
+ {t('documentPanel.clearDocuments.title')} + {t('documentPanel.clearDocuments.confirm')} diff --git a/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx b/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx index 7149eb28..7f17393c 100644 --- a/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx +++ b/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx @@ -14,8 +14,10 @@ import { errorMessage } from '@/lib/utils' import { uploadDocument } from '@/api/lightrag' import { UploadIcon } from 'lucide-react' +import { useTranslation } from 'react-i18next' export default function UploadDocumentsDialog() { + const { t } = useTranslation() const [open, setOpen] = useState(false) const [isUploading, setIsUploading] = useState(false) const [progresses, setProgresses] = useState>({}) @@ -29,24 +31,24 @@ export default function UploadDocumentsDialog() { filesToUpload.map(async (file) => { try { const result = await uploadDocument(file, (percentCompleted: number) => { - console.debug(`Uploading ${file.name}: ${percentCompleted}%`) + console.debug(t('documentPanel.uploadDocuments.uploading', { name: file.name, percent: percentCompleted })) setProgresses((pre) => ({ ...pre, [file.name]: percentCompleted })) }) if (result.status === 'success') { - toast.success(`Upload Success:\n${file.name} uploaded successfully`) + toast.success(t('documentPanel.uploadDocuments.success', { name: file.name })) } else { - toast.error(`Upload Failed:\n${file.name}\n${result.message}`) + toast.error(t('documentPanel.uploadDocuments.failed', { name: file.name, message: result.message })) } } catch (err) { - toast.error(`Upload Failed:\n${file.name}\n${errorMessage(err)}`) + toast.error(t('documentPanel.uploadDocuments.error', { name: file.name, error: errorMessage(err) })) } }) ) } catch (err) { - toast.error('Upload Failed\n' + errorMessage(err)) + toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) })) } finally { setIsUploading(false) // setOpen(false) @@ -66,21 +68,21 @@ export default function UploadDocumentsDialog() { }} > - e.preventDefault()}> - Upload documents + {t('documentPanel.uploadDocuments.title')} - Drag and drop your documents here or click to browse. + {t('documentPanel.uploadDocuments.description')} { const { isFullScreen, toggle } = useFullScreen() + const { t } = useTranslation() return ( <> {isFullScreen ? ( - ) : ( - )} diff --git a/lightrag_webui/src/components/graph/GraphLabels.tsx b/lightrag_webui/src/components/graph/GraphLabels.tsx index a3849e1f..7bc26c88 100644 --- a/lightrag_webui/src/components/graph/GraphLabels.tsx +++ b/lightrag_webui/src/components/graph/GraphLabels.tsx @@ -5,6 +5,7 @@ import { useSettingsStore } from '@/stores/settings' import { useGraphStore } from '@/stores/graph' import { labelListLimit } from '@/lib/constants' import MiniSearch from 'minisearch' +import { useTranslation } from 'react-i18next' const lastGraph: any = { graph: null, @@ -13,6 +14,7 @@ const lastGraph: any = { } const GraphLabels = () => { + const { t } = useTranslation() const label = useSettingsStore.use.queryLabel() const graph = useGraphStore.use.sigmaGraph() @@ -69,7 +71,7 @@ const GraphLabels = () => { return result.length <= labelListLimit ? 
result
-          : [...result.slice(0, labelListLimit), `And ${result.length - labelListLimit} others`]
+          : [...result.slice(0, labelListLimit), t('graphPanel.graphLabels.andOthers', { count: result.length - labelListLimit })]
     },
     [getSearchEngine]
   )
{item}
} getOptionValue={(item) => item} getDisplayValue={(item) =>
{item}
} notFound={
No labels found
} - label="Label" - placeholder="Search labels..." + label={t('graphPanel.graphLabels.label')} + placeholder={t('graphPanel.graphLabels.placeholder')} value={label !== null ? label : ''} onChange={setQueryLabel} /> diff --git a/lightrag_webui/src/components/graph/GraphSearch.tsx b/lightrag_webui/src/components/graph/GraphSearch.tsx index 3edc3ede..bbb8cb5b 100644 --- a/lightrag_webui/src/components/graph/GraphSearch.tsx +++ b/lightrag_webui/src/components/graph/GraphSearch.tsx @@ -9,6 +9,7 @@ import { AsyncSearch } from '@/components/ui/AsyncSearch' import { searchResultLimit } from '@/lib/constants' import { useGraphStore } from '@/stores/graph' import MiniSearch from 'minisearch' +import { useTranslation } from 'react-i18next' interface OptionItem { id: string @@ -44,6 +45,7 @@ export const GraphSearchInput = ({ onFocus?: GraphSearchInputProps['onFocus'] value?: GraphSearchInputProps['value'] }) => { + const { t } = useTranslation() const graph = useGraphStore.use.sigmaGraph() const searchEngine = useMemo(() => { @@ -97,7 +99,7 @@ export const GraphSearchInput = ({ { type: 'message', id: messageId, - message: `And ${result.length - searchResultLimit} others` + message: t('graphPanel.search.message', { count: result.length - searchResultLimit }) } ] }, @@ -118,7 +120,7 @@ export const GraphSearchInput = ({ if (id !== messageId && onFocus) onFocus(id ? { id, type: 'nodes' } : null) }} label={'item'} - placeholder="Search nodes..." + placeholder={t('graphPanel.search.placeholder')} /> ) } diff --git a/lightrag_webui/src/components/graph/LayoutsControl.tsx b/lightrag_webui/src/components/graph/LayoutsControl.tsx index c57b371a..0ed97f2f 100644 --- a/lightrag_webui/src/components/graph/LayoutsControl.tsx +++ b/lightrag_webui/src/components/graph/LayoutsControl.tsx @@ -16,6 +16,7 @@ import { controlButtonVariant } from '@/lib/constants' import { useSettingsStore } from '@/stores/settings' import { GripIcon, PlayIcon, PauseIcon } from 'lucide-react' +import { useTranslation } from 'react-i18next' type LayoutName = | 'Circular' @@ -28,6 +29,7 @@ type LayoutName = const WorkerLayoutControl = ({ layout, autoRunFor }: WorkerLayoutControlProps) => { const sigma = useSigma() const { stop, start, isRunning } = layout + const { t } = useTranslation() /** * Init component when Sigma or component settings change. @@ -61,7 +63,7 @@ const WorkerLayoutControl = ({ layout, autoRunFor }: WorkerLayoutControlProps) = @@ -166,7 +169,7 @@ const LayoutsControl = () => { key={name} className="cursor-pointer text-xs" > - {name} + {t(`graphPanel.sideBar.layoutsControl.layouts.${name}`)} ))} diff --git a/lightrag_webui/src/components/graph/PropertiesView.tsx b/lightrag_webui/src/components/graph/PropertiesView.tsx index dec80460..4571b02b 100644 --- a/lightrag_webui/src/components/graph/PropertiesView.tsx +++ b/lightrag_webui/src/components/graph/PropertiesView.tsx @@ -2,6 +2,7 @@ import { useEffect, useState } from 'react' import { useGraphStore, RawNodeType, RawEdgeType } from '@/stores/graph' import Text from '@/components/ui/Text' import useLightragGraph from '@/hooks/useLightragGraph' +import { useTranslation } from 'react-i18next' /** * Component that view properties of elements in graph. @@ -147,21 +148,22 @@ const PropertyRow = ({ } const NodePropertiesView = ({ node }: { node: NodeType }) => { + const { t } = useTranslation() return (
- +
- + { useGraphStore.getState().setSelectedNode(node.id, true) }} /> - +
- +
{Object.keys(node.properties) .sort() @@ -172,7 +174,7 @@ const NodePropertiesView = ({ node }: { node: NodeType }) => { {node.relationships.length > 0 && ( <>
{node.relationships.map(({ type, id, label }) => { @@ -195,28 +197,29 @@ const NodePropertiesView = ({ node }: { node: NodeType }) => { } const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => { + const { t } = useTranslation() return (
- +
- - {edge.type && } + + {edge.type && } { useGraphStore.getState().setSelectedNode(edge.source, true) }} /> { useGraphStore.getState().setSelectedNode(edge.target, true) }} />
- +
          {Object.keys(edge.properties)
            .sort()
diff --git a/lightrag_webui/src/components/graph/Settings.tsx b/lightrag_webui/src/components/graph/Settings.tsx
index 67fb1ded..4a4b15a5 100644
--- a/lightrag_webui/src/components/graph/Settings.tsx
+++ b/lightrag_webui/src/components/graph/Settings.tsx
@@ -10,6 +10,7 @@ import { useSettingsStore } from '@/stores/settings'
 import { useBackendState } from '@/stores/state'
 import { SettingsIcon } from 'lucide-react'
+import { useTranslation } from 'react-i18next'

 /**
  * Component that displays a checkbox with a label.
  */
@@ -204,10 +205,12 @@ export default function Settings() {
     [setTempApiKey]
   )

+  const { t } = useTranslation()
+
   return (
-
@@ -221,7 +224,7 @@
@@ -229,12 +232,12 @@
@@ -242,12 +245,12 @@
@@ -255,51 +258,50 @@
-
- +
e.preventDefault()}>
@@ -310,7 +312,7 @@
             size="sm"
             className="max-h-full shrink-0"
           >
-            Save
+            {t('graphPanel.sideBar.settings.save')}
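The locale resources these keys resolve against are not part of this excerpt. Below is a hypothetical slice of an English resource for the key used above — only the `save` leaf is attested in this hunk; the object nesting is inferred from the dotted path `graphPanel.sideBar.settings.save`.

```typescript
// Hypothetical en-locale slice; the project's actual resource files are not
// shown in this diff. The nesting mirrors the dotted key path used above.
const en = {
  graphPanel: {
    sideBar: {
      settings: {
        save: 'Save'
      }
    }
  }
}

export default en
```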
diff --git a/lightrag_webui/src/components/graph/StatusCard.tsx b/lightrag_webui/src/components/graph/StatusCard.tsx index 3084d103..e67cbd30 100644 --- a/lightrag_webui/src/components/graph/StatusCard.tsx +++ b/lightrag_webui/src/components/graph/StatusCard.tsx @@ -1,58 +1,60 @@ import { LightragStatus } from '@/api/lightrag' +import { useTranslation } from 'react-i18next' const StatusCard = ({ status }: { status: LightragStatus | null }) => { + const { t } = useTranslation() if (!status) { - return
Status information unavailable
+ return
{t('graphPanel.statusCard.unavailable')}
} return (
-

Storage Info

+

{t('graphPanel.statusCard.storageInfo')}

- Working Directory: + {t('graphPanel.statusCard.workingDirectory')}: {status.working_directory} - Input Directory: + {t('graphPanel.statusCard.inputDirectory')}: {status.input_directory}
-

LLM Configuration

+

{t('graphPanel.statusCard.llmConfig')}

- LLM Binding: + {t('graphPanel.statusCard.llmBinding')}: {status.configuration.llm_binding} - LLM Binding Host: + {t('graphPanel.statusCard.llmBindingHost')}: {status.configuration.llm_binding_host} - LLM Model: + {t('graphPanel.statusCard.llmModel')}: {status.configuration.llm_model} - Max Tokens: + {t('graphPanel.statusCard.maxTokens')}: {status.configuration.max_tokens}
-

Embedding Configuration

+

{t('graphPanel.statusCard.embeddingConfig')}

- Embedding Binding: + {t('graphPanel.statusCard.embeddingBinding')}: {status.configuration.embedding_binding} - Embedding Binding Host: + {t('graphPanel.statusCard.embeddingBindingHost')}: {status.configuration.embedding_binding_host} - Embedding Model: + {t('graphPanel.statusCard.embeddingModel')}: {status.configuration.embedding_model}
-

Storage Configuration

+

{t('graphPanel.statusCard.storageConfig')}

- KV Storage: + {t('graphPanel.statusCard.kvStorage')}: {status.configuration.kv_storage} - Doc Status Storage: + {t('graphPanel.statusCard.docStatusStorage')}: {status.configuration.doc_status_storage} - Graph Storage: + {t('graphPanel.statusCard.graphStorage')}: {status.configuration.graph_storage} - Vector Storage: + {t('graphPanel.statusCard.vectorStorage')}: {status.configuration.vector_storage}
diff --git a/lightrag_webui/src/components/graph/StatusIndicator.tsx b/lightrag_webui/src/components/graph/StatusIndicator.tsx index 3272d9fa..d7a1831f 100644 --- a/lightrag_webui/src/components/graph/StatusIndicator.tsx +++ b/lightrag_webui/src/components/graph/StatusIndicator.tsx @@ -3,8 +3,10 @@ import { useBackendState } from '@/stores/state' import { useEffect, useState } from 'react' import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/Popover' import StatusCard from '@/components/graph/StatusCard' +import { useTranslation } from 'react-i18next' const StatusIndicator = () => { + const { t } = useTranslation() const health = useBackendState.use.health() const lastCheckTime = useBackendState.use.lastCheckTime() const status = useBackendState.use.status() @@ -33,7 +35,7 @@ const StatusIndicator = () => { )} /> - {health ? 'Connected' : 'Disconnected'} + {health ? t('graphPanel.statusIndicator.connected') : t('graphPanel.statusIndicator.disconnected')}
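Several calls in this diff pass an options object, e.g. `t('graphPanel.search.message', { count: ... })` in GraphSearch above. In i18next that maps to `{{count}}` interpolation in the resource string (and, if the locale defines `message_one`/`message_other` variants, plural selection). A standalone sketch, with the resource value assumed rather than taken from this diff:

```typescript
import i18next from 'i18next'

async function demo() {
  // The resource value here is an assumption; the project's locale files are
  // not part of this excerpt.
  await i18next.init({
    lng: 'en',
    resources: {
      en: { translation: { graphPanel: { search: { message: 'And {{count}} others' } } } }
    }
  })
  // The `count` option fills the {{count}} placeholder.
  console.log(i18next.t('graphPanel.search.message', { count: 7 })) // And 7 others
}

demo()
```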
diff --git a/lightrag_webui/src/components/graph/ZoomControl.tsx b/lightrag_webui/src/components/graph/ZoomControl.tsx
index 790b4423..0aa55416 100644
--- a/lightrag_webui/src/components/graph/ZoomControl.tsx
+++ b/lightrag_webui/src/components/graph/ZoomControl.tsx
@@ -3,12 +3,14 @@
 import { useCallback } from 'react'
 import Button from '@/components/ui/Button'
 import { ZoomInIcon, ZoomOutIcon, FullscreenIcon } from 'lucide-react'
 import { controlButtonVariant } from '@/lib/constants'
+import { useTranslation } from 'react-i18next'

 /**
  * Component that provides zoom controls for the graph viewer.
  */
 const ZoomControl = () => {
   const { zoomIn, zoomOut, reset } = useCamera({ duration: 200, factor: 1.5 })
+  const { t } = useTranslation()

   const handleZoomIn = useCallback(() => zoomIn(), [zoomIn])
   const handleZoomOut = useCallback(() => zoomOut(), [zoomOut])
@@ -16,16 +18,16 @@
   return (
     <>
-
-
@@ -98,29 +100,29 @@ export default function DocumentManager() { - Uploaded documents - view the uploaded documents here + {t('documentPanel.documentManager.uploadedTitle')} + {t('documentPanel.documentManager.uploadedDescription')} {!docs && ( )} {docs && ( - ID - Summary - Status - Length - Chunks - Created - Updated - Metadata + {t('documentPanel.documentManager.columns.id')} + {t('documentPanel.documentManager.columns.summary')} + {t('documentPanel.documentManager.columns.status')} + {t('documentPanel.documentManager.columns.length')} + {t('documentPanel.documentManager.columns.chunks')} + {t('documentPanel.documentManager.columns.created')} + {t('documentPanel.documentManager.columns.updated')} + {t('documentPanel.documentManager.columns.metadata')} @@ -137,13 +139,13 @@ export default function DocumentManager() { {status === 'processed' && ( - Completed + {t('documentPanel.documentManager.status.completed')} )} {status === 'processing' && ( - Processing + {t('documentPanel.documentManager.status.processing')} )} - {status === 'pending' && Pending} - {status === 'failed' && Failed} + {status === 'pending' && {t('documentPanel.documentManager.status.pending')}} + {status === 'failed' && {t('documentPanel.documentManager.status.failed')}} {doc.error && ( ⚠️ diff --git a/lightrag_webui/src/features/RetrievalTesting.tsx b/lightrag_webui/src/features/RetrievalTesting.tsx index 340255a2..c7fdf2a9 100644 --- a/lightrag_webui/src/features/RetrievalTesting.tsx +++ b/lightrag_webui/src/features/RetrievalTesting.tsx @@ -8,8 +8,10 @@ import { useDebounce } from '@/hooks/useDebounce' import QuerySettings from '@/components/retrieval/QuerySettings' import { ChatMessage, MessageWithError } from '@/components/retrieval/ChatMessage' import { EraserIcon, SendIcon } from 'lucide-react' +import { useTranslation } from 'react-i18next' export default function RetrievalTesting() { + const { t } = useTranslation() const [messages, setMessages] = useState( () => useSettingsStore.getState().retrievalHistory || [] ) @@ -89,7 +91,7 @@ export default function RetrievalTesting() { } } catch (err) { // Handle error - updateAssistantMessage(`Error: Failed to get response\n${errorMessage(err)}`, true) + updateAssistantMessage(`${t('retrievePanel.retrieval.error')}\n${errorMessage(err)}`, true) } finally { // Clear loading and add messages to state setIsLoading(false) @@ -98,7 +100,7 @@ export default function RetrievalTesting() { .setRetrievalHistory([...prevMessages, userMessage, assistantMessage]) } }, - [inputValue, isLoading, messages, setMessages] + [inputValue, isLoading, messages, setMessages, t] ) const debouncedMessages = useDebounce(messages, 100) @@ -117,7 +119,7 @@ export default function RetrievalTesting() {
{messages.length === 0 ? (
- Start a retrieval by typing your query below + {t('retrievePanel.retrieval.startPrompt')}
) : ( messages.map((message, idx) => ( @@ -143,18 +145,18 @@ export default function RetrievalTesting() { size="sm" > - Clear + {t('retrievePanel.retrieval.clear')} setInputValue(e.target.value)} - placeholder="Type your query..." + placeholder={t('retrievePanel.retrieval.placeholder')} disabled={isLoading} />
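Note how the RetrievalTesting hunk above adds `t` to the useCallback dependency list. `useTranslation()` returns a `t` bound to the active language and re-renders the component on language change, so any memoized callback that formats user-facing strings should depend on `t`, or it will keep producing text in the previous locale. A minimal sketch of the pattern (the hook and key names are illustrative, not part of the patch):

```typescript
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'

// Illustrative hook: including `t` in the dependency array rebuilds the
// callback whenever the language (and therefore `t`) changes.
export function useFormatError() {
  const { t } = useTranslation()
  return useCallback(
    (detail: string) => `${t('retrievePanel.retrieval.error')}\n${detail}`,
    [t]
  )
}
```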
diff --git a/lightrag_webui/src/features/SiteHeader.tsx b/lightrag_webui/src/features/SiteHeader.tsx index c09ce089..ac3bdd70 100644 --- a/lightrag_webui/src/features/SiteHeader.tsx +++ b/lightrag_webui/src/features/SiteHeader.tsx @@ -4,6 +4,7 @@ import ThemeToggle from '@/components/ThemeToggle' import { TabsList, TabsTrigger } from '@/components/ui/Tabs' import { useSettingsStore } from '@/stores/settings' import { cn } from '@/lib/utils' +import { useTranslation } from 'react-i18next' import { ZapIcon, GithubIcon } from 'lucide-react' @@ -29,21 +30,22 @@ function NavigationTab({ value, currentTab, children }: NavigationTabProps) { function TabsNavigation() { const currentTab = useSettingsStore.use.currentTab() + const { t } = useTranslation() return (
- Documents + {t('header.documents')} - Knowledge Graph + {t('header.knowledgeGraph')} - Retrieval + {t('header.retrieval')} - API + {t('header.api')}
@@ -51,6 +53,7 @@ function TabsNavigation() { } export default function SiteHeader() { + const { t } = useTranslation() return (
@@ -64,7 +67,7 @@ export default function SiteHeader() {
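i18next and react-i18next enter the dependency lists above, but the module that initializes them and binds `useTranslation()` to React is not part of this excerpt. A minimal sketch of a typical bootstrap, with the file path and resource content assumed:

```typescript
// src/i18n.ts (hypothetical path): minimal react-i18next bootstrap. Importing
// this module once before rendering makes the useTranslation() calls resolve.
import i18n from 'i18next'
import { initReactI18next } from 'react-i18next'

i18n
  .use(initReactI18next) // bind i18next to React
  .init({
    lng: 'en',
    fallbackLng: 'en',
    interpolation: { escapeValue: false }, // React escapes rendered output itself
    resources: {
      // Illustrative keys drawn from the SiteHeader hunk above.
      en: { translation: { header: { documents: 'Documents', retrieval: 'Retrieval' } } }
    }
  })

export default i18n
```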