From aa5888042e02625568785cafde8996f0b3d16831 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 23:57:57 +0800 Subject: [PATCH 01/32] Improved file handling and validation for document processing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Enhanced UTF-8 validation for text files • Added content validation checks • Better handling of binary data • Added logging for ignored document IDs • Improved document ID filtering --- lightrag/api/routers/document_routes.py | 28 +++++++++++++++++++++---- lightrag/lightrag.py | 18 +++++++++++++++- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index ab5aff96..f7f87c2b 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -215,7 +215,27 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: | ".scss" | ".less" ): - content = file.decode("utf-8") + try: + # Try to decode as UTF-8 + content = file.decode("utf-8") + + # Validate content + if not content or len(content.strip()) == 0: + logger.error(f"Empty content in file: {file_path.name}") + return False + + # Check if content looks like binary data string representation + if content.startswith("b'") or content.startswith('b"'): + logger.error( + f"File {file_path.name} appears to contain binary data representation instead of text" + ) + return False + + except UnicodeDecodeError: + logger.error( + f"File {file_path.name} is not valid UTF-8 encoded text. Please convert it to UTF-8 before processing." + ) + return False case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") @@ -229,7 +249,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: case ".docx": if not pm.is_installed("docx"): pm.install("docx") - from docx import Document + from docx import Document # type: ignore from io import BytesIO docx_file = BytesIO(file) @@ -238,7 +258,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: case ".pptx": if not pm.is_installed("pptx"): pm.install("pptx") - from pptx import Presentation + from pptx import Presentation # type: ignore from io import BytesIO pptx_file = BytesIO(file) @@ -250,7 +270,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: case ".xlsx": if not pm.is_installed("openpyxl"): pm.install("openpyxl") - from openpyxl import load_workbook + from openpyxl import load_workbook # type: ignore from io import BytesIO xlsx_file = BytesIO(file) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8d9c1678..daf5c059 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -670,8 +670,24 @@ class LightRAG: all_new_doc_ids = set(new_docs.keys()) # Exclude IDs of documents that are already in progress unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids) + + # Log ignored document IDs + ignored_ids = [ + doc_id for doc_id in unique_new_doc_ids if doc_id not in new_docs + ] + if ignored_ids: + logger.warning( + f"Ignoring {len(ignored_ids)} document IDs not found in new_docs" + ) + for doc_id in ignored_ids: + logger.warning(f"Ignored document ID: {doc_id}") + # Filter new_docs to only include documents with unique IDs - new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids} + new_docs = { + doc_id: new_docs[doc_id] + for doc_id in unique_new_doc_ids + if doc_id in new_docs + } if not new_docs: logger.info("No new unique documents were found.") 
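The validation added in PATCH 01 boils down to three checks: the bytes must decode as UTF-8, the decoded text must be non-empty, and it must not look like a stringified bytes literal. A minimal standalone sketch of the same logic follows; the helper name is hypothetical, and the real change lives inside pipeline_enqueue_file() as shown in the diff above.

    def is_valid_text_payload(raw: bytes) -> bool:
        try:
            content = raw.decode("utf-8")        # must be valid UTF-8 text
        except UnicodeDecodeError:
            return False
        if not content.strip():                  # must contain non-whitespace content
            return False
        if content.startswith(("b'", 'b"')):     # must not be a str(bytes) artifact saved as text
            return False
        return True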
From b07181ca39ca034d165f19ddfbae09a45ae738a3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 3 Mar 2025 01:59:20 +0800 Subject: [PATCH 02/32] Remove duplicated run_with_gunicorn.py from project root --- run_with_gunicorn.py | 203 ------------------------------------------- 1 file changed, 203 deletions(-) delete mode 100755 run_with_gunicorn.py diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py deleted file mode 100755 index 2e4e3cf7..00000000 --- a/run_with_gunicorn.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python -""" -Start LightRAG server with Gunicorn -""" - -import os -import sys -import signal -import pipmaster as pm -from lightrag.api.utils_api import parse_args, display_splash_screen -from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data - - -def check_and_install_dependencies(): - """Check and install required dependencies""" - required_packages = [ - "gunicorn", - "tiktoken", - "psutil", - # Add other required packages here - ] - - for package in required_packages: - if not pm.is_installed(package): - print(f"Installing {package}...") - pm.install(package) - print(f"{package} installed successfully") - - -# Signal handler for graceful shutdown -def signal_handler(sig, frame): - print("\n\n" + "=" * 80) - print("RECEIVED TERMINATION SIGNAL") - print(f"Process ID: {os.getpid()}") - print("=" * 80 + "\n") - - # Release shared resources - finalize_share_data() - - # Exit with success status - sys.exit(0) - - -def main(): - # Check and install dependencies - check_and_install_dependencies() - - # Register signal handlers for graceful shutdown - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # kill command - - # Parse all arguments using parse_args - args = parse_args(is_uvicorn_mode=False) - - # Display startup information - display_splash_screen(args) - - print("🚀 Starting LightRAG with Gunicorn") - print(f"🔄 Worker management: Gunicorn (workers={args.workers})") - print("🔍 Preloading app: Enabled") - print("📝 Note: Using Gunicorn's preload feature for shared data initialization") - print("\n\n" + "=" * 80) - print("MAIN PROCESS INITIALIZATION") - print(f"Process ID: {os.getpid()}") - print(f"Workers setting: {args.workers}") - print("=" * 80 + "\n") - - # Import Gunicorn's StandaloneApplication - from gunicorn.app.base import BaseApplication - - # Define a custom application class that loads our config - class GunicornApp(BaseApplication): - def __init__(self, app, options=None): - self.options = options or {} - self.application = app - super().__init__() - - def load_config(self): - # Define valid Gunicorn configuration options - valid_options = { - "bind", - "workers", - "worker_class", - "timeout", - "keepalive", - "preload_app", - "errorlog", - "accesslog", - "loglevel", - "certfile", - "keyfile", - "limit_request_line", - "limit_request_fields", - "limit_request_field_size", - "graceful_timeout", - "max_requests", - "max_requests_jitter", - } - - # Special hooks that need to be set separately - special_hooks = { - "on_starting", - "on_reload", - "on_exit", - "pre_fork", - "post_fork", - "pre_exec", - "pre_request", - "post_request", - "worker_init", - "worker_exit", - "nworkers_changed", - "child_exit", - } - - # Import and configure the gunicorn_config module - import gunicorn_config - - # Set configuration variables in gunicorn_config, prioritizing command line arguments - gunicorn_config.workers = ( - args.workers if args.workers else int(os.getenv("WORKERS", 1)) - ) - - # 
Bind configuration prioritizes command line arguments - host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") - port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) - gunicorn_config.bind = f"{host}:{port}" - - # Log level configuration prioritizes command line arguments - gunicorn_config.loglevel = ( - args.log_level.lower() - if args.log_level - else os.getenv("LOG_LEVEL", "info") - ) - - # Timeout configuration prioritizes command line arguments - gunicorn_config.timeout = ( - args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) - ) - - # Keepalive configuration - gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) - - # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ( - "true", - "1", - "yes", - "t", - "on", - ): - gunicorn_config.certfile = ( - args.ssl_certfile - if args.ssl_certfile - else os.getenv("SSL_CERTFILE") - ) - gunicorn_config.keyfile = ( - args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") - ) - - # Set configuration options from the module - for key in dir(gunicorn_config): - if key in valid_options: - value = getattr(gunicorn_config, key) - # Skip functions like on_starting and None values - if not callable(value) and value is not None: - self.cfg.set(key, value) - # Set special hooks - elif key in special_hooks: - value = getattr(gunicorn_config, key) - if callable(value): - self.cfg.set(key, value) - - if hasattr(gunicorn_config, "logconfig_dict"): - self.cfg.set( - "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") - ) - - def load(self): - # Import the application - from lightrag.api.lightrag_server import get_application - - return get_application(args) - - # Create the application - app = GunicornApp("") - - # Force workers to be an integer and greater than 1 for multi-process mode - workers_count = int(args.workers) - if workers_count > 1: - # Set a flag to indicate we're in the main process - os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" - initialize_share_data(workers_count) - else: - initialize_share_data(1) - - # Run the application - print("\nStarting Gunicorn with direct Python API...") - app.run() - - -if __name__ == "__main__": - main() From 7a866cbe216a46b97de9a8f931d1338aebe763c9 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 3 Mar 2025 11:48:43 +0100 Subject: [PATCH 03/32] Update run_with_gunicorn.py --- lightrag/api/run_with_gunicorn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index 903c5c17..f27ddd61 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -9,7 +9,10 @@ import signal import pipmaster as pm from lightrag.api.utils_api import parse_args, display_splash_screen from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data +from dotenv import load_dotenv +# Load environment variables from .env file +load_dotenv() def check_and_install_dependencies(): """Check and install required dependencies""" From ff3f29d2406c27d1d28e74ec96a8311858009bda Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 3 Mar 2025 12:13:01 +0100 Subject: [PATCH 04/32] Update run_with_gunicorn.py --- lightrag/api/run_with_gunicorn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index f27ddd61..e7143a39 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -12,6 +12,7 @@ 
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_dat from dotenv import load_dotenv # Load environment variables from .env file +print("Current folder: {}".format(os.getcwd())) load_dotenv() def check_and_install_dependencies(): From e87feb76bc86a94c97d7407dda7f329455936e0c Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 3 Mar 2025 12:21:15 +0100 Subject: [PATCH 05/32] Update run_with_gunicorn.py --- lightrag/api/run_with_gunicorn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index e7143a39..50dd195d 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -12,8 +12,9 @@ from lightrag.kg.shared_storage import initialize_share_data, finalize_share_dat from dotenv import load_dotenv # Load environment variables from .env file -print("Current folder: {}".format(os.getcwd())) +print(f"Current folder: {os.getcwd()}") load_dotenv() +print(f"Check: {os.getenv('LLM_MODEL')}") def check_and_install_dependencies(): """Check and install required dependencies""" From bda931e1d2abc934b7418e0c2ad4e73734d2366d Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 3 Mar 2025 12:21:50 +0100 Subject: [PATCH 06/32] Update run_with_gunicorn.py --- lightrag/api/run_with_gunicorn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index 50dd195d..71844fe0 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -13,7 +13,7 @@ from dotenv import load_dotenv # Load environment variables from .env file print(f"Current folder: {os.getcwd()}") -load_dotenv() +load_dotenv(".env") print(f"Check: {os.getenv('LLM_MODEL')}") def check_and_install_dependencies(): From 52bedc9118892d1ad214cda7ed6164d06f27e574 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 3 Mar 2025 12:22:37 +0100 Subject: [PATCH 07/32] Update run_with_gunicorn.py --- lightrag/api/run_with_gunicorn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index 71844fe0..4e5353bf 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -11,10 +11,8 @@ from lightrag.api.utils_api import parse_args, display_splash_screen from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data from dotenv import load_dotenv -# Load environment variables from .env file -print(f"Current folder: {os.getcwd()}") +# Updated to use the .env that is inside the current folder load_dotenv(".env") -print(f"Check: {os.getenv('LLM_MODEL')}") def check_and_install_dependencies(): """Check and install required dependencies""" From 7b3e39473065935570466537a0c4b7139fb3d176 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 3 Mar 2025 12:23:47 +0100 Subject: [PATCH 08/32] Update run_with_gunicorn.py --- lightrag/api/run_with_gunicorn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index 4e5353bf..231a1727 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -12,6 +12,7 @@ from lightrag.kg.shared_storage import initialize_share_data, finalize_share_dat from dotenv import load_dotenv # Updated to use the .env that is inside the current folder +# This update allows the user to put a different.env file for each lightrag folder 
load_dotenv(".env") def check_and_install_dependencies(): From 5680e9ef11ef8403f65823af59a7d52c29548d15 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Mon, 3 Mar 2025 12:24:49 +0100 Subject: [PATCH 09/32] Update lightrag_server.py --- lightrag/api/lightrag_server.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5f2c437f..637595d3 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -20,7 +20,7 @@ from ascii_colors import ASCIIColors from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from dotenv import load_dotenv -from .utils_api import ( +from lightrag.api.utils_api import ( get_api_key_dependency, parse_args, get_default_host, @@ -30,14 +30,14 @@ from lightrag import LightRAG from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ from lightrag.utils import EmbeddingFunc -from .routers.document_routes import ( +from lightrag.api.routers.document_routes import ( DocumentManager, create_document_routes, run_scanning_process, ) -from .routers.query_routes import create_query_routes -from .routers.graph_routes import create_graph_routes -from .routers.ollama_api import OllamaAPI +from lightrag.api.routers.query_routes import create_query_routes +from lightrag.api.routers.graph_routes import create_graph_routes +from lightrag.api.routers.ollama_api import OllamaAPI from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( @@ -48,7 +48,9 @@ from lightrag.kg.shared_storage import ( ) # Load environment variables -load_dotenv(override=True) +# Updated to use the .env that is inside the current folder +# This update allows the user to put a different.env file for each lightrag folder +load_dotenv(".env", override=True) # Initialize config parser config = configparser.ConfigParser() From 462c27c1672c30f54313f832fac182d1587e8092 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 3 Mar 2025 23:18:41 +0800 Subject: [PATCH 10/32] Refactor logging setup and simplify Gunicorn configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Move logging setup code to utils.py • Provide setup_logger for standalone LightRAG logger intialization --- lightrag/api/gunicorn_config.py | 53 ++------- lightrag/api/lightrag_server.py | 3 + lightrag/lightrag.py | 3 - lightrag/utils.py | 96 +++++++++++++++ run_with_gunicorn.py | 203 -------------------------------- 5 files changed, 112 insertions(+), 246 deletions(-) delete mode 100755 run_with_gunicorn.py diff --git a/lightrag/api/gunicorn_config.py b/lightrag/api/gunicorn_config.py index 7f9b4d58..0594ceae 100644 --- a/lightrag/api/gunicorn_config.py +++ b/lightrag/api/gunicorn_config.py @@ -2,12 +2,15 @@ import os import logging from lightrag.kg.shared_storage import finalize_share_data -from lightrag.api.lightrag_server import LightragPathFilter +from lightrag.utils import setup_logger # Get log directory path from environment variable log_dir = os.getenv("LOG_DIR", os.getcwd()) log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) +# Ensure log directory exists +os.makedirs(os.path.dirname(log_file_path), exist_ok=True) + # Get log file max size and backup count from environment variables log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups @@ -108,6 
+111,9 @@ def on_starting(server): except ImportError: print("psutil not installed, skipping memory usage reporting") + # Log the location of the LightRAG log file + print(f"LightRAG log file: {log_file_path}\n") + print("Gunicorn initialization complete, forking workers...\n") @@ -134,51 +140,18 @@ def post_fork(server, worker): Executed after a worker has been forked. This is a good place to set up worker-specific configurations. """ - # Configure formatters - detailed_formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - simple_formatter = logging.Formatter("%(levelname)s: %(message)s") - - def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False): - """Set up a logger with console and file handlers""" - logger_instance = logging.getLogger(logger_name) - logger_instance.setLevel(level) - logger_instance.handlers = [] # Clear existing handlers - logger_instance.propagate = False - - # Add console handler - console_handler = logging.StreamHandler() - console_handler.setFormatter(simple_formatter) - console_handler.setLevel(level) - logger_instance.addHandler(console_handler) - - # Add file handler - file_handler = logging.handlers.RotatingFileHandler( - filename=log_file_path, - maxBytes=log_max_bytes, - backupCount=log_backup_count, - encoding="utf-8", - ) - file_handler.setFormatter(detailed_formatter) - file_handler.setLevel(level) - logger_instance.addHandler(file_handler) - - # Add path filter if requested - if add_filter: - path_filter = LightragPathFilter() - logger_instance.addFilter(path_filter) - # Set up main loggers log_level = loglevel.upper() if loglevel else "INFO" - setup_logger("uvicorn", log_level) - setup_logger("uvicorn.access", log_level, add_filter=True) - setup_logger("lightrag", log_level, add_filter=True) + setup_logger("uvicorn", log_level, add_filter=False, log_file_path=log_file_path) + setup_logger( + "uvicorn.access", log_level, add_filter=True, log_file_path=log_file_path + ) + setup_logger("lightrag", log_level, add_filter=True, log_file_path=log_file_path) # Set up lightrag submodule loggers for name in logging.root.manager.loggerDict: if name.startswith("lightrag."): - setup_logger(name, log_level, add_filter=True) + setup_logger(name, log_level, add_filter=True, log_file_path=log_file_path) # Disable uvicorn.error logger uvicorn_error_logger = logging.getLogger("uvicorn.error") diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5f2c437f..693c6a9f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -437,6 +437,9 @@ def configure_logging(): log_dir = os.getenv("LOG_DIR", os.getcwd()) log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + print(f"\nLightRAG log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + # Get log file max size and backup count from environment variables log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 208bdf3e..adcb1029 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -266,9 +266,6 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): - os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) - logger.info(f"Logger initialized for working directory: {self.working_dir}") - from lightrag.kg.shared_storage 
import ( initialize_share_data, ) diff --git a/lightrag/utils.py b/lightrag/utils.py index c86ad8c0..bb1d6fae 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -6,6 +6,7 @@ import io import csv import json import logging +import logging.handlers import os import re from dataclasses import dataclass @@ -68,6 +69,101 @@ logger.setLevel(logging.INFO) logging.getLogger("httpx").setLevel(logging.WARNING) +class LightragPathFilter(logging.Filter): + """Filter for lightrag logger to filter out frequent path access logs""" + + def __init__(self): + super().__init__() + # Define paths to be filtered + self.filtered_paths = ["/documents", "/health", "/webui/"] + + def filter(self, record): + try: + # Check if record has the required attributes for an access log + if not hasattr(record, "args") or not isinstance(record.args, tuple): + return True + if len(record.args) < 5: + return True + + # Extract method, path and status from the record args + method = record.args[1] + path = record.args[2] + status = record.args[4] + + # Filter out successful GET requests to filtered paths + if ( + method == "GET" + and (status == 200 or status == 304) + and path in self.filtered_paths + ): + return False + + return True + except Exception: + # In case of any error, let the message through + return True + + +def setup_logger( + logger_name: str, + level: str = "INFO", + add_filter: bool = False, + log_file_path: str = None, +): + """Set up a logger with console and file handlers + + Args: + logger_name: Name of the logger to set up + level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + add_filter: Whether to add LightragPathFilter to the logger + log_file_path: Path to the log file. If None, will use current directory/lightrag.log + """ + # Configure formatters + detailed_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + simple_formatter = logging.Formatter("%(levelname)s: %(message)s") + + # Get log file path + if log_file_path is None: + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + + # Ensure log directory exists + os.makedirs(os.path.dirname(log_file_path), exist_ok=True) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups + + logger_instance = logging.getLogger(logger_name) + logger_instance.setLevel(level) + logger_instance.handlers = [] # Clear existing handlers + logger_instance.propagate = False + + # Add console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(simple_formatter) + console_handler.setLevel(level) + logger_instance.addHandler(console_handler) + + # Add file handler + file_handler = logging.handlers.RotatingFileHandler( + filename=log_file_path, + maxBytes=log_max_bytes, + backupCount=log_backup_count, + encoding="utf-8", + ) + file_handler.setFormatter(detailed_formatter) + file_handler.setLevel(level) + logger_instance.addHandler(file_handler) + + # Add path filter if requested + if add_filter: + path_filter = LightragPathFilter() + logger_instance.addFilter(path_filter) + + class UnlimitedSemaphore: """A context manager that allows unlimited access.""" diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py deleted file mode 100755 index 2e4e3cf7..00000000 --- a/run_with_gunicorn.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python -""" -Start 
LightRAG server with Gunicorn -""" - -import os -import sys -import signal -import pipmaster as pm -from lightrag.api.utils_api import parse_args, display_splash_screen -from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data - - -def check_and_install_dependencies(): - """Check and install required dependencies""" - required_packages = [ - "gunicorn", - "tiktoken", - "psutil", - # Add other required packages here - ] - - for package in required_packages: - if not pm.is_installed(package): - print(f"Installing {package}...") - pm.install(package) - print(f"{package} installed successfully") - - -# Signal handler for graceful shutdown -def signal_handler(sig, frame): - print("\n\n" + "=" * 80) - print("RECEIVED TERMINATION SIGNAL") - print(f"Process ID: {os.getpid()}") - print("=" * 80 + "\n") - - # Release shared resources - finalize_share_data() - - # Exit with success status - sys.exit(0) - - -def main(): - # Check and install dependencies - check_and_install_dependencies() - - # Register signal handlers for graceful shutdown - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # kill command - - # Parse all arguments using parse_args - args = parse_args(is_uvicorn_mode=False) - - # Display startup information - display_splash_screen(args) - - print("🚀 Starting LightRAG with Gunicorn") - print(f"🔄 Worker management: Gunicorn (workers={args.workers})") - print("🔍 Preloading app: Enabled") - print("📝 Note: Using Gunicorn's preload feature for shared data initialization") - print("\n\n" + "=" * 80) - print("MAIN PROCESS INITIALIZATION") - print(f"Process ID: {os.getpid()}") - print(f"Workers setting: {args.workers}") - print("=" * 80 + "\n") - - # Import Gunicorn's StandaloneApplication - from gunicorn.app.base import BaseApplication - - # Define a custom application class that loads our config - class GunicornApp(BaseApplication): - def __init__(self, app, options=None): - self.options = options or {} - self.application = app - super().__init__() - - def load_config(self): - # Define valid Gunicorn configuration options - valid_options = { - "bind", - "workers", - "worker_class", - "timeout", - "keepalive", - "preload_app", - "errorlog", - "accesslog", - "loglevel", - "certfile", - "keyfile", - "limit_request_line", - "limit_request_fields", - "limit_request_field_size", - "graceful_timeout", - "max_requests", - "max_requests_jitter", - } - - # Special hooks that need to be set separately - special_hooks = { - "on_starting", - "on_reload", - "on_exit", - "pre_fork", - "post_fork", - "pre_exec", - "pre_request", - "post_request", - "worker_init", - "worker_exit", - "nworkers_changed", - "child_exit", - } - - # Import and configure the gunicorn_config module - import gunicorn_config - - # Set configuration variables in gunicorn_config, prioritizing command line arguments - gunicorn_config.workers = ( - args.workers if args.workers else int(os.getenv("WORKERS", 1)) - ) - - # Bind configuration prioritizes command line arguments - host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") - port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) - gunicorn_config.bind = f"{host}:{port}" - - # Log level configuration prioritizes command line arguments - gunicorn_config.loglevel = ( - args.log_level.lower() - if args.log_level - else os.getenv("LOG_LEVEL", "info") - ) - - # Timeout configuration prioritizes command line arguments - gunicorn_config.timeout = ( - args.timeout if 
args.timeout else int(os.getenv("TIMEOUT", 150)) - ) - - # Keepalive configuration - gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) - - # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ( - "true", - "1", - "yes", - "t", - "on", - ): - gunicorn_config.certfile = ( - args.ssl_certfile - if args.ssl_certfile - else os.getenv("SSL_CERTFILE") - ) - gunicorn_config.keyfile = ( - args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") - ) - - # Set configuration options from the module - for key in dir(gunicorn_config): - if key in valid_options: - value = getattr(gunicorn_config, key) - # Skip functions like on_starting and None values - if not callable(value) and value is not None: - self.cfg.set(key, value) - # Set special hooks - elif key in special_hooks: - value = getattr(gunicorn_config, key) - if callable(value): - self.cfg.set(key, value) - - if hasattr(gunicorn_config, "logconfig_dict"): - self.cfg.set( - "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") - ) - - def load(self): - # Import the application - from lightrag.api.lightrag_server import get_application - - return get_application(args) - - # Create the application - app = GunicornApp("") - - # Force workers to be an integer and greater than 1 for multi-process mode - workers_count = int(args.workers) - if workers_count > 1: - # Set a flag to indicate we're in the main process - os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" - initialize_share_data(workers_count) - else: - initialize_share_data(1) - - # Run the application - print("\nStarting Gunicorn with direct Python API...") - app.run() - - -if __name__ == "__main__": - main() From b26a574f40253ce6a32380b0f29b6d42d75ab0d6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 01:07:34 +0800 Subject: [PATCH 11/32] Deprecate log_level and log_file_path in LightRAG. - Remove log_level from API initialization - Add warnings for deprecated logging params --- README.md | 18 +++++++++++++++--- lightrag/api/lightrag_server.py | 2 -- lightrag/lightrag.py | 25 ++++++++++++++++++++----- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index abc2f8b3..5e8c5a94 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,9 @@ import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import setup_logger + +setup_logger("lightrag", level="INFO") async def initialize_rag(): rag = LightRAG( @@ -344,6 +347,10 @@ from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_i from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.openai import OpenAI from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import setup_logger + +# Setup log handler for LightRAG +setup_logger("lightrag", level="INFO") async def initialize_rag(): rag = LightRAG( @@ -640,6 +647,9 @@ export NEO4J_URI="neo4j://localhost:7687" export NEO4J_USERNAME="neo4j" export NEO4J_PASSWORD="password" +# Setup logger for LightRAG +setup_logger("lightrag", level="INFO") + # When you launch the project be sure to override the default KG: NetworkX # by specifying kg="Neo4JStorage". 
@@ -649,8 +659,12 @@ rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model graph_storage="Neo4JStorage", #<-----------override KG default - log_level="DEBUG" #<-----------override log_level default ) + +# Initialize database connections +await rag.initialize_storages() +# Initialize pipeline status for document processing +await initialize_pipeline_status() ``` see test_neo4j.py for a working example. @@ -859,7 +873,6 @@ Valid modes are: | **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` | | **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` | | **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` | -| **log\_level** | | Log level for application runtime | `logging.DEBUG` | | **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` | | **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` | | **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` | @@ -881,7 +894,6 @@ Valid modes are: | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | | **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | -|**log\_dir** | `str` | Directory to store logs. | `./` | diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 693c6a9f..c91f693f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -329,7 +329,6 @@ def create_app(args): "similarity_threshold": 0.95, "use_llm_check": False, }, - log_level=args.log_level, namespace_prefix=args.namespace_prefix, auto_manage_storages_states=False, ) @@ -359,7 +358,6 @@ def create_app(args): "similarity_threshold": 0.95, "use_llm_check": False, }, - log_level=args.log_level, namespace_prefix=args.namespace_prefix, auto_manage_storages_states=False, ) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 4dacac08..114b5735 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import configparser import os +import warnings from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial @@ -85,14 +86,10 @@ class LightRAG: doc_status_storage: str = field(default="JsonDocStatusStorage") """Storage type for tracking document processing statuses.""" - # Logging + # Logging (Deprecated, use setup_logger in utils.py instead) # --- - log_level: int = field(default=logger.level) - """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" - log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log")) - """Log file path.""" # Entity extraction # --- @@ -270,6 +267,24 @@ class LightRAG: initialize_share_data, ) + # Handle deprecated parameters + kwargs = self.__dict__ + if "log_level" in kwargs: + warnings.warn( + "WARNING: log_level parameter is deprecated, use setup_logger in utils.py instead", + UserWarning, + stacklevel=2, + ) + # Remove the attribute to prevent its use + delattr(self, "log_level") + if "log_file_path" in kwargs: + warnings.warn( + "WARNING: log_file_path parameter is deprecated, use setup_logger in utils.py instead", + UserWarning, + stacklevel=2, + ) + delattr(self, "log_file_path") + initialize_share_data() if not os.path.exists(self.working_dir): From 905699429281c576f9abbd29ae8c247b64bcda29 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 01:28:08 +0800 Subject: [PATCH 12/32] Deprecate and remove logging parameters in LightRAG. 
- Set log_level and log_file_path to None by default - Issue warnings if deprecated parameters are used - Maintain backward compatibility with warnings --- lightrag/lightrag.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 114b5735..21688b7d 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -88,8 +88,8 @@ class LightRAG: # Logging (Deprecated, use setup_logger in utils.py instead) # --- - log_level: int = field(default=logger.level) - log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log")) + log_level: int | None = field(default=None) + log_file_path: str | None = field(default=None) # Entity extraction # --- @@ -268,21 +268,23 @@ class LightRAG: ) # Handle deprecated parameters - kwargs = self.__dict__ - if "log_level" in kwargs: + if self.log_level is not None: warnings.warn( "WARNING: log_level parameter is deprecated, use setup_logger in utils.py instead", UserWarning, stacklevel=2, ) - # Remove the attribute to prevent its use - delattr(self, "log_level") - if "log_file_path" in kwargs: + if self.log_file_path is not None: warnings.warn( "WARNING: log_file_path parameter is deprecated, use setup_logger in utils.py instead", UserWarning, stacklevel=2, ) + + # Remove these attributes to prevent their use + if hasattr(self, "log_level"): + delattr(self, "log_level") + if hasattr(self, "log_file_path"): delattr(self, "log_file_path") initialize_share_data() From 0af774a28f92b0fd6c2ba9ebcd8ce49f697a3eab Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 01:28:39 +0800 Subject: [PATCH 13/32] Fix linting --- lightrag/lightrag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 21688b7d..a2d806b6 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -280,7 +280,7 @@ class LightRAG: UserWarning, stacklevel=2, ) - + # Remove these attributes to prevent their use if hasattr(self, "log_level"): delattr(self, "log_level") From bc9905a06177961b6f0e78f1da967e1b45ecf8cf Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 02:28:09 +0800 Subject: [PATCH 14/32] Fix gensim not compatible wtih numpy and scipy problem - Replace numpy with gensim in requirements.txt - Let gensim choose a correct version of numpy and scipy --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a1a1157e..d9a5c68e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ configparser future # Basic modules -numpy +gensim pipmaster pydantic python-dotenv From 61839f311a566531c038da57a0451272eff1d9c3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 10:00:07 +0800 Subject: [PATCH 15/32] Fix package name checks for docx and pptx modules. 
- Added type ignore for package checks - Corrected docx pptx package name for new version --- lightrag/api/routers/document_routes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index ab5aff96..39314233 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -217,7 +217,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: ): content = file.decode("utf-8") case ".pdf": - if not pm.is_installed("pypdf2"): + if not pm.is_installed("pypdf2"): # type: ignore pm.install("pypdf2") from PyPDF2 import PdfReader # type: ignore from io import BytesIO @@ -227,7 +227,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: for page in reader.pages: content += page.extract_text() + "\n" case ".docx": - if not pm.is_installed("docx"): + if not pm.is_installed("python-docx"): # type: ignore pm.install("docx") from docx import Document from io import BytesIO @@ -236,7 +236,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: doc = Document(docx_file) content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) case ".pptx": - if not pm.is_installed("pptx"): + if not pm.is_installed("python-pptx"): # type: ignore pm.install("pptx") from pptx import Presentation from io import BytesIO @@ -248,7 +248,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: if hasattr(shape, "text"): content += shape.text + "\n" case ".xlsx": - if not pm.is_installed("openpyxl"): + if not pm.is_installed("openpyxl"): # type: ignore pm.install("openpyxl") from openpyxl import load_workbook from io import BytesIO From b12c05ec0a228ad6b2d99fb0c99d2c62131eb5d3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 12:09:00 +0800 Subject: [PATCH 16/32] fix: api server installation missing MANIFEST.in file - Added MANIFEST.in to include webui files - Removed /webui/ endpoint from lightrag_server.py --- MANIFEST.in | 1 + lightrag/api/lightrag_server.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..44c3aff1 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include lightrag/api/webui * diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5f2c437f..8695d6b6 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -410,10 +410,6 @@ def create_app(args): name="webui", ) - @app.get("/webui/") - async def webui_root(): - return FileResponse(static_dir / "index.html") - return app From d7f7c07251edf21d8460b9a91ba31c06aad9314e Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 12:19:40 +0800 Subject: [PATCH 17/32] Fix linting --- lightrag/api/lightrag_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8695d6b6..631fa238 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -6,7 +6,6 @@ from fastapi import ( FastAPI, Depends, ) -from fastapi.responses import FileResponse import asyncio import os import logging From 6c8fa9521477b3a9440f640337b060fa46f2f5c8 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 12:25:07 +0800 Subject: [PATCH 18/32] fix demo --- README.md | 21 ++++--- examples/lightrag_azure_openai_demo.py | 58 +++++++++++-------- examples/lightrag_bedrock_demo.py | 
4 ++ examples/lightrag_nvidia_demo.py | 2 +- examples/lightrag_openai_compatible_demo.py | 2 +- ..._openai_compatible_demo_embedding_cache.py | 2 +- examples/lightrag_oracle_demo.py | 2 +- examples/lightrag_tidb_demo.py | 2 +- examples/lightrag_zhipu_postgres_demo.py | 2 +- examples/query_keyword_separation_example.py | 2 +- 10 files changed, 58 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 5e8c5a94..f863d9ed 100644 --- a/README.md +++ b/README.md @@ -655,16 +655,19 @@ setup_logger("lightrag", level="INFO") # Note: Default settings use NetworkX # Initialize LightRAG with Neo4J implementation. -rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model - graph_storage="Neo4JStorage", #<-----------override KG default -) +async def initialize_rag(): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model + graph_storage="Neo4JStorage", #<-----------override KG default + ) -# Initialize database connections -await rag.initialize_storages() -# Initialize pipeline status for document processing -await initialize_pipeline_status() + # Initialize database connections + await rag.initialize_storages() + # Initialize pipeline status for document processing + await initialize_pipeline_status() + + return rag ``` see test_neo4j.py for a working example. diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py index e0840366..c101383d 100644 --- a/examples/lightrag_azure_openai_demo.py +++ b/examples/lightrag_azure_openai_demo.py @@ -81,34 +81,46 @@ asyncio.run(test_funcs()) embedding_dimension = 3072 -rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=8192, - func=embedding_func, - ), -) -rag.initialize_storages() -initialize_pipeline_status() +async def initialize_rag(): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, + ), + ) -book1 = open("./book_1.txt", encoding="utf-8") -book2 = open("./book_2.txt", encoding="utf-8") + await rag.initialize_storages() + await initialize_pipeline_status() -rag.insert([book1.read(), book2.read()]) + return rag -query_text = "What are the main themes?" -print("Result (Naive):") -print(rag.query(query_text, param=QueryParam(mode="naive"))) +def main(): + rag = asyncio.run(initialize_rag()) -print("\nResult (Local):") -print(rag.query(query_text, param=QueryParam(mode="local"))) + book1 = open("./book_1.txt", encoding="utf-8") + book2 = open("./book_2.txt", encoding="utf-8") -print("\nResult (Global):") -print(rag.query(query_text, param=QueryParam(mode="global"))) + rag.insert([book1.read(), book2.read()]) -print("\nResult (Hybrid):") -print(rag.query(query_text, param=QueryParam(mode="hybrid"))) + query_text = "What are the main themes?" 
+ + print("Result (Naive):") + print(rag.query(query_text, param=QueryParam(mode="naive"))) + + print("\nResult (Local):") + print(rag.query(query_text, param=QueryParam(mode="local"))) + + print("\nResult (Global):") + print(rag.query(query_text, param=QueryParam(mode="global"))) + + print("\nResult (Hybrid):") + print(rag.query(query_text, param=QueryParam(mode="hybrid"))) + + +if __name__ == "__main__": + main() diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py index 68e9f962..c7f41677 100644 --- a/examples/lightrag_bedrock_demo.py +++ b/examples/lightrag_bedrock_demo.py @@ -53,3 +53,7 @@ def main(): "What are the top themes in this story?", param=QueryParam(mode=mode) ) ) + + +if __name__ == "__main__": + main() diff --git a/examples/lightrag_nvidia_demo.py b/examples/lightrag_nvidia_demo.py index 6de0814c..0e9259bc 100644 --- a/examples/lightrag_nvidia_demo.py +++ b/examples/lightrag_nvidia_demo.py @@ -125,7 +125,7 @@ async def initialize_rag(): async def main(): try: # Initialize RAG instance - rag = asyncio.run(initialize_rag()) + rag = await initialize_rag() # reading file with open("./book.txt", "r", encoding="utf-8") as f: diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index 1c4a7a92..d26a8de3 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -77,7 +77,7 @@ async def initialize_rag(): async def main(): try: # Initialize RAG instance - rag = asyncio.run(initialize_rag()) + rag = await initialize_rag() with open("./book.txt", "r", encoding="utf-8") as f: await rag.ainsert(f.read()) diff --git a/examples/lightrag_openai_compatible_demo_embedding_cache.py b/examples/lightrag_openai_compatible_demo_embedding_cache.py index 85408f3b..4638219f 100644 --- a/examples/lightrag_openai_compatible_demo_embedding_cache.py +++ b/examples/lightrag_openai_compatible_demo_embedding_cache.py @@ -81,7 +81,7 @@ async def initialize_rag(): async def main(): try: # Initialize RAG instance - rag = asyncio.run(initialize_rag()) + rag = await initialize_rag() with open("./book.txt", "r", encoding="utf-8") as f: await rag.ainsert(f.read()) diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 420f1af0..6663f6a1 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -107,7 +107,7 @@ async def initialize_rag(): async def main(): try: # Initialize RAG instance - rag = asyncio.run(initialize_rag()) + rag = await initialize_rag() # Extract and Insert into LightRAG storage with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py index f167e9cc..52695560 100644 --- a/examples/lightrag_tidb_demo.py +++ b/examples/lightrag_tidb_demo.py @@ -87,7 +87,7 @@ async def initialize_rag(): async def main(): try: # Initialize RAG instance - rag = asyncio.run(initialize_rag()) + rag = await initialize_rag() with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py index 304c5f2c..e4a20f26 100644 --- a/examples/lightrag_zhipu_postgres_demo.py +++ b/examples/lightrag_zhipu_postgres_demo.py @@ -59,7 +59,7 @@ async def initialize_rag(): async def main(): # Initialize RAG instance - rag = asyncio.run(initialize_rag()) + rag = await initialize_rag() # add embedding_func for graph database, it's deleted in commit 
5661d76860436f7bf5aef2e50d9ee4a59660146c rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func diff --git a/examples/query_keyword_separation_example.py b/examples/query_keyword_separation_example.py index cbfdd930..092330f4 100644 --- a/examples/query_keyword_separation_example.py +++ b/examples/query_keyword_separation_example.py @@ -102,7 +102,7 @@ async def initialize_rag(): # Example function demonstrating the new query_with_separate_keyword_extraction usage async def run_example(): # Initialize RAG instance - rag = asyncio.run(initialize_rag()) + rag = await initialize_rag() book1 = open("./book_1.txt", encoding="utf-8") book2 = open("./book_2.txt", encoding="utf-8") From 23106b81fbaeb9f4ddff9c18881874c458d8ab26 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 12:29:17 +0800 Subject: [PATCH 19/32] fix custom kg demo --- README.md | 70 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index f863d9ed..ed257049 100644 --- a/README.md +++ b/README.md @@ -505,44 +505,58 @@ rag.query_with_separate_keyword_extraction( ```python custom_kg = { + "chunks": [ + { + "content": "Alice and Bob are collaborating on quantum computing research.", + "source_id": "doc-1" + } + ], "entities": [ { - "entity_name": "CompanyA", - "entity_type": "Organization", - "description": "A major technology company", - "source_id": "Source1" + "entity_name": "Alice", + "entity_type": "person", + "description": "Alice is a researcher specializing in quantum physics.", + "source_id": "doc-1" }, { - "entity_name": "ProductX", - "entity_type": "Product", - "description": "A popular product developed by CompanyA", - "source_id": "Source1" + "entity_name": "Bob", + "entity_type": "person", + "description": "Bob is a mathematician.", + "source_id": "doc-1" + }, + { + "entity_name": "Quantum Computing", + "entity_type": "technology", + "description": "Quantum computing utilizes quantum mechanical phenomena for computation.", + "source_id": "doc-1" } ], "relationships": [ { - "src_id": "CompanyA", - "tgt_id": "ProductX", - "description": "CompanyA develops ProductX", - "keywords": "develop, produce", + "src_id": "Alice", + "tgt_id": "Bob", + "description": "Alice and Bob are research partners.", + "keywords": "collaboration research", "weight": 1.0, - "source_id": "Source1" + "source_id": "doc-1" + }, + { + "src_id": "Alice", + "tgt_id": "Quantum Computing", + "description": "Alice conducts research on quantum computing.", + "keywords": "research expertise", + "weight": 1.0, + "source_id": "doc-1" + }, + { + "src_id": "Bob", + "tgt_id": "Quantum Computing", + "description": "Bob researches quantum computing.", + "keywords": "research application", + "weight": 1.0, + "source_id": "doc-1" } - ], - "chunks": [ - { - "content": "ProductX, developed by CompanyA, has revolutionized the market with its cutting-edge features.", - "source_id": "Source1", - }, - { - "content": "PersonA is a prominent researcher at UniversityB, focusing on artificial intelligence and machine learning.", - "source_id": "Source2", - }, - { - "content": "None", - "source_id": "UNKNOWN", - }, - ], + ] } rag.insert_custom_kg(custom_kg) From 0f430ca1a7f6058897c8a1cff098b6630801011c Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 12:42:40 +0800 Subject: [PATCH 20/32] update README.md --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ed257049..57563a1f 100644 --- 
a/README.md +++ b/README.md @@ -785,7 +785,8 @@ rag.delete_by_doc_id("doc_id") LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph. -### Create Entities and Relations +
+ Create Entities and Relations ```python # Create new entity @@ -807,8 +808,10 @@ relation = rag.create_relation("Google", "Gmail", { "weight": 2.0 }) ``` +
-### Edit Entities and Relations +
+ Edit Entities and Relations ```python # Edit an existing entity @@ -830,6 +833,7 @@ updated_relation = rag.edit_relation("Google", "Google Mail", { "weight": 3.0 }) ``` +
All operations are available in both synchronous and asynchronous versions. The asynchronous versions have the prefix "a" (e.g., `acreate_entity`, `aedit_relation`). From 6c39cbf773145fd098bfc7b81c7fa8722c6a1338 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 12:45:35 +0800 Subject: [PATCH 21/32] Add summary language setting by env --- lightrag/lightrag.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a5d3c94b..ea302822 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -36,7 +36,7 @@ from .operate import ( mix_kg_vector_query, naive_query, ) -from .prompt import GRAPH_FIELD_SEP +from .prompt import GRAPH_FIELD_SEP, PROMPTS from .utils import ( EmbeddingFunc, always_get_an_event_loop, @@ -236,7 +236,9 @@ class LightRAG: max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20))) """Maximum number of parallel insert operations.""" - addon_params: dict[str, Any] = field(default_factory=dict) + addon_params: dict[str, Any] = field(default_factory=lambda: { + "language": os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"]) + }) # Storages Management # --- From fd9f71e0eee26189f19448d04678ff5dc0254524 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 13:22:33 +0800 Subject: [PATCH 22/32] fix delete_by_doc_id --- lightrag/kg/json_kv_impl.py | 9 +++++++++ lightrag/kg/tidb_impl.py | 8 ++++++++ lightrag/lightrag.py | 33 +++++++++++++++++++++++++-------- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 8d707899..c0b61a63 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -44,6 +44,15 @@ class JsonKVStorage(BaseKVStorage): ) write_json(data_dict, self._file_name) + async def get_all(self) -> dict[str, Any]: + """Get all data from storage + + Returns: + Dictionary containing all stored data + """ + async with self._storage_lock: + return dict(self._data) + async def get_by_id(self, id: str) -> dict[str, Any] | None: async with self._storage_lock: return self._data.get(id) diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 4adb0141..51d1c365 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -174,6 +174,14 @@ class TiDBKVStorage(BaseKVStorage): self.db = None ################ QUERY METHODS ################ + async def get_all(self) -> dict[str, Any]: + """Get all data from storage + + Returns: + Dictionary containing all stored data + """ + async with self._storage_lock: + return dict(self._data) async def get_by_id(self, id: str) -> dict[str, Any] | None: """Fetch doc_full data by id.""" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a5d3c94b..b2e9845e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1431,14 +1431,22 @@ class LightRAG: logger.debug(f"Starting deletion for document {doc_id}") - doc_to_chunk_id = doc_id.replace("doc", "chunk") + # 2. Get all chunks related to this document + # Find all chunks where full_doc_id equals the current doc_id + all_chunks = await self.text_chunks.get_all() + related_chunks = { + chunk_id: chunk_data + for chunk_id, chunk_data in all_chunks.items() + if isinstance(chunk_data, dict) + and chunk_data.get("full_doc_id") == doc_id + } - # 2. 
Get all related chunks - chunks = await self.text_chunks.get_by_id(doc_to_chunk_id) - if not chunks: + if not related_chunks: + logger.warning(f"No chunks found for document {doc_id}") return - chunk_ids = {chunks["full_doc_id"].replace("doc", "chunk")} + # Get all related chunk IDs + chunk_ids = set(related_chunks.keys()) logger.debug(f"Found {len(chunk_ids)} chunks to delete") # 3. Before deleting, check the related entities and relationships for these chunks @@ -1626,9 +1634,18 @@ class LightRAG: logger.warning(f"Document {doc_id} still exists in full_docs") # Verify if chunks have been deleted - remaining_chunks = await self.text_chunks.get_by_id(doc_to_chunk_id) - if remaining_chunks: - logger.warning(f"Found {len(remaining_chunks)} remaining chunks") + all_remaining_chunks = await self.text_chunks.get_all() + remaining_related_chunks = { + chunk_id: chunk_data + for chunk_id, chunk_data in all_remaining_chunks.items() + if isinstance(chunk_data, dict) + and chunk_data.get("full_doc_id") == doc_id + } + + if remaining_related_chunks: + logger.warning( + f"Found {len(remaining_related_chunks)} remaining chunks" + ) # Verify entities and relationships for chunk_id in chunk_ids: From 06b2124dd0e9774c5d81352060d7c5ca4a4a5ce8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 14:02:14 +0800 Subject: [PATCH 23/32] Fix linting --- lightrag/lightrag.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ea302822..27fdafda 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -236,9 +236,11 @@ class LightRAG: max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20))) """Maximum number of parallel insert operations.""" - addon_params: dict[str, Any] = field(default_factory=lambda: { - "language": os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"]) - }) + addon_params: dict[str, Any] = field( + default_factory=lambda: { + "language": os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"]) + } + ) # Storages Management # --- From 0679ca4055d36dfd53afcb9ab87ea5d4c056cd31 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 14:20:55 +0800 Subject: [PATCH 24/32] Update neo4j_impl.py --- lightrag/kg/neo4j_impl.py | 92 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 1 deletion(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index dccee330..fec39138 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -690,8 +690,98 @@ class Neo4JStorage(BaseGraphStorage): labels.append(record["label"]) return labels + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, + ) + ), + ) async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + """ + label = await self._ensure_label(node_id) + + async def _do_delete(tx: AsyncManagedTransaction): + query = f""" + MATCH (n:`{label}`) + DETACH DELETE n + """ + await tx.run(query) + logger.debug(f"Deleted node with label '{label}'") + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise + + 
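A brief usage sketch of the deletion helpers introduced in this patch: `delete_node` above, plus the `remove_nodes` and `remove_edges` methods that follow. The `graph_storage` name and the node labels are illustrative assumptions only:

```python
async def prune_graph(graph_storage):
    # Assumes `graph_storage` is an initialized Neo4JStorage instance
    await graph_storage.delete_node("Google")                   # DETACH DELETE one node and its edges
    await graph_storage.remove_nodes(["Gmail", "Google Mail"])  # delete several nodes, one at a time
    await graph_storage.remove_edges([("Google", "Gmail")])     # delete directed edges given as (source, target) pairs
```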
@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, + ) + ), + ) + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ + for node in nodes: + await self.delete_node(node) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, + ) + ), + ) + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + source_label = await self._ensure_label(source) + target_label = await self._ensure_label(target) + + async def _do_delete_edge(tx: AsyncManagedTransaction): + query = f""" + MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`) + DELETE r + """ + await tx.run(query) + logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'") + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def embed_nodes( self, algorithm: str From 3a2a6368628fd2d54851ed6b1de8026cdf3cf608 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 15:50:53 +0800 Subject: [PATCH 25/32] Implement the missing methods. --- lightrag/kg/age_impl.py | 248 ++++++++++++++++++++++++++++++- lightrag/kg/chroma_impl.py | 34 ++++- lightrag/kg/gremlin_impl.py | 279 ++++++++++++++++++++++++++++++++++- lightrag/kg/milvus_impl.py | 83 ++++++++++- lightrag/kg/mongo_impl.py | 109 +++++++++++++- lightrag/kg/oracle_impl.py | 255 +++++++++++++++++++++++++++++++- lightrag/kg/postgres_impl.py | 243 +++++++++++++++++++++++++++++- lightrag/kg/qdrant_impl.py | 89 ++++++++++- lightrag/kg/redis_impl.py | 78 +++++++++- lightrag/kg/tidb_impl.py | 180 +++++++++++++++++++++- lightrag/lightrag.py | 48 ++++++ 11 files changed, 1603 insertions(+), 43 deletions(-) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 97b3825d..c6b98221 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Union, final import numpy as np import pipmaster as pm -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from tenacity import ( retry, @@ -613,20 +613,258 @@ class AGEStorage(BaseGraphStorage): await self._driver.putconn(connection) async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + """ + entity_name_label = node_id.strip('"') + + query = """ + MATCH (n:`{label}`) + DETACH DELETE n + """ + params = {"label": AGEStorage._encode_graph_label(entity_name_label)} + try: + await self._query(query, **params) + logger.debug(f"Deleted node with label '{entity_name_label}'") + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise + + async def remove_nodes(self, 
nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + entity_name_label_source = source.strip('"') + entity_name_label_target = target.strip('"') + + query = """ + MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`) + DELETE r + """ + params = { + "src_label": AGEStorage._encode_graph_label(entity_name_label_source), + "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target) + } + try: + await self._query(query, **params) + logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'") + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError + """Embed nodes using the specified algorithm + + Args: + algorithm: Name of the embedding algorithm + + Returns: + tuple: (embedding matrix, list of node identifiers) + """ + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() async def get_all_labels(self) -> list[str]: - raise NotImplementedError + """Get all node labels in the database + + Returns: + ["label1", "label2", ...] # Alphabetically sorted label list + """ + query = """ + MATCH (n) + RETURN DISTINCT labels(n) AS node_labels + """ + results = await self._query(query) + + all_labels = [] + for record in results: + if record and "node_labels" in record: + for label in record["node_labels"]: + if label: + # Decode label + decoded_label = AGEStorage._decode_graph_label(label) + all_labels.append(decoded_label) + + # Remove duplicates and sort + return sorted(list(set(all_labels))) async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Retrieve a connected subgraph of nodes where the label includes the specified 'node_label'. + Maximum number of nodes is constrained by the environment variable 'MAX_GRAPH_NODES' (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence (nodes containing the specified label string) + 2. Followed by nodes directly connected to the matching nodes + 3. Finally, the degree of the nodes + + Args: + node_label: String to match in node labels (will match any node containing this string in its label) + max_depth: Maximum depth of the graph. Defaults to 5. 
+ + Returns: + KnowledgeGraph: Complete connected subgraph for specified node + """ + max_graph_nodes = int(os.getenv("MAX_GRAPH_NODES", 1000)) + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + # Handle special case for "*" label + if node_label == "*": + # Query all nodes and sort by degree + query = """ + MATCH (n) + OPTIONAL MATCH (n)-[r]-() + WITH n, count(r) AS degree + ORDER BY degree DESC + LIMIT {max_nodes} + RETURN n, degree + """ + params = {"max_nodes": max_graph_nodes} + nodes_result = await self._query(query, **params) + + # Add nodes to result + node_ids = [] + for record in nodes_result: + if "n" in record: + node = record["n"] + node_id = str(node.get("id", "")) + if node_id not in seen_nodes: + node_properties = {k: v for k, v in node.items()} + node_label = node.get("label", "") + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_label], + properties=node_properties + ) + ) + seen_nodes.add(node_id) + node_ids.append(node_id) + + # Query edges between these nodes + if node_ids: + edges_query = """ + MATCH (a)-[r]->(b) + WHERE a.id IN {node_ids} AND b.id IN {node_ids} + RETURN a, r, b + """ + edges_params = {"node_ids": node_ids} + edges_result = await self._query(edges_query, **edges_params) + + # Add edges to result + for record in edges_result: + if "r" in record and "a" in record and "b" in record: + source = record["a"].get("id", "") + target = record["b"].get("id", "") + edge_id = f"{source}-{target}" + if edge_id not in seen_edges: + edge_properties = {k: v for k, v in record["r"].items()} + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=source, + target=target, + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + else: + # For specific label, use partial matching + entity_name_label = node_label.strip('"') + encoded_label = AGEStorage._encode_graph_label(entity_name_label) + + # Find matching start nodes + start_query = """ + MATCH (n:`{label}`) + RETURN n + """ + start_params = {"label": encoded_label} + start_nodes = await self._query(start_query, **start_params) + + if not start_nodes: + logger.warning(f"No nodes found with label '{entity_name_label}'!") + return result + + # Traverse graph from each start node + for start_node_record in start_nodes: + if "n" in start_node_record: + start_node = start_node_record["n"] + start_id = str(start_node.get("id", "")) + + # Use BFS to traverse graph + query = """ + MATCH (start:`{label}`) + CALL { + MATCH path = (start)-[*0..{max_depth}]->(n) + RETURN nodes(path) AS path_nodes, relationships(path) AS path_rels + } + RETURN DISTINCT path_nodes, path_rels + """ + params = {"label": encoded_label, "max_depth": max_depth} + results = await self._query(query, **params) + + # Extract nodes and edges from results + for record in results: + if "path_nodes" in record: + # Process nodes + for node in record["path_nodes"]: + node_id = str(node.get("id", "")) + if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes: + node_properties = {k: v for k, v in node.items()} + node_label = node.get("label", "") + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_label], + properties=node_properties + ) + ) + seen_nodes.add(node_id) + + if "path_rels" in record: + # Process edges + for rel in record["path_rels"]: + source = str(rel.get("start_id", "")) + target = str(rel.get("end_id", "")) + edge_id = f"{source}-{target}" + if edge_id not in seen_edges: + edge_properties = {k: v for k, v in rel.items()} + 
result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=rel.get("label", "DIRECTED"), + source=source, + target=target, + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result async def index_done_callback(self) -> None: # AGES handles persistence automatically diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 3b726c8b..d36e6d7c 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -193,7 +193,37 @@ class ChromaVectorDBStorage(BaseVectorStorage): pass async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by its ID. + + Args: + entity_name: The ID of the entity to delete + """ + try: + logger.info(f"Deleting entity with ID {entity_name} from {self.namespace}") + self._collection.delete(ids=[entity_name]) + except Exception as e: + logger.error(f"Error during entity deletion: {str(e)}") + raise async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity and its relations by ID. + In vector DB context, this is equivalent to delete_entity. + + Args: + entity_name: The ID of the entity to delete + """ + await self.delete_entity(entity_name) + + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") + self._collection.delete(ids=ids) + logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + raise diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 3a26401d..4d343bb5 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -16,7 +16,7 @@ from tenacity import ( wait_exponential, ) -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from ..base import BaseGraphStorage @@ -396,17 +396,286 @@ class GremlinStorage(BaseGraphStorage): print("Implemented but never called.") async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node with the specified entity_name + + Args: + node_id: The entity_name of the node to delete + """ + entity_name = GremlinStorage._fix_name(node_id) + + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name}) + .drop() + """ + try: + await self._query(query) + logger.debug( + "{%s}: Deleted node with entity_name '%s'", + inspect.currentframe().f_code.co_name, + entity_name + ) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError + """ + Embed nodes using the specified algorithm. + Currently, only node2vec is supported but never called. 
+ + Args: + algorithm: The name of the embedding algorithm to use + + Returns: + A tuple of (embeddings, node_ids) + + Raises: + NotImplementedError: If the specified algorithm is not supported + ValueError: If the algorithm is not supported + """ + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + return await self._node_embed_algorithms[algorithm]() async def get_all_labels(self) -> list[str]: - raise NotImplementedError + """ + Get all node entity_names in the graph + Returns: + [entity_name1, entity_name2, ...] # Alphabetically sorted entity_name list + """ + query = f"""g + .V().has('graph', {self.graph_name}) + .values('entity_name') + .dedup() + .order() + """ + try: + result = await self._query(query) + labels = result if result else [] + logger.debug( + "{%s}: Retrieved %d labels", + inspect.currentframe().f_code.co_name, + len(labels) + ) + return labels + except Exception as e: + logger.error(f"Error retrieving labels: {str(e)}") + return [] async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`. + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). + + Args: + node_label: Entity name of the starting node + max_depth: Maximum depth of the subgraph + + Returns: + KnowledgeGraph object containing nodes and edges + """ + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + # Get maximum number of graph nodes from environment variable, default is 1000 + MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + + entity_name = GremlinStorage._fix_name(node_label) + + # Handle special case for "*" label + if node_label == "*": + # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES) + query = f"""g + .V().has('graph', {self.graph_name}) + .limit({MAX_GRAPH_NODES}) + .elementMap() + """ + nodes_result = await self._query(query) + + # Add nodes to result + for node_data in nodes_result: + node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + if str(node_id) in seen_nodes: + continue + + # Create node with properties + node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node_id), + labels=[str(node_id)], + properties=node_properties + ) + ) + seen_nodes.add(str(node_id)) + + # Get and add edges + if nodes_result: + query = f"""g + .V().has('graph', {self.graph_name}) + .limit({MAX_GRAPH_NODES}) + .outE() + .inV().has('graph', {self.graph_name}) + .limit({MAX_GRAPH_NODES}) + .path() + .by(elementMap()) + .by(elementMap()) + .by(elementMap()) + """ + edges_result = await self._query(query) + + for path in edges_result: + if len(path) >= 3: # source -> edge -> target + source = path[0] + edge_data = path[1] + target = path[2] + + source_id = source.get('entity_name', str(source.get('id', ''))) + target_id = target.get('entity_name', str(target.get('id', ''))) + + edge_id = f"{source_id}-{target_id}" + if edge_id in seen_edges: + continue + + # Create edge with properties + edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} + + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source_id), + target=str(target_id), + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + else: + # Search 
for specific node and get its neighborhood + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name}) + .repeat(__.both().simplePath().dedup()) + .times({max_depth}) + .emit() + .dedup() + .limit({MAX_GRAPH_NODES}) + .elementMap() + """ + nodes_result = await self._query(query) + + # Add nodes to result + for node_data in nodes_result: + node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + if str(node_id) in seen_nodes: + continue + + # Create node with properties + node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node_id), + labels=[str(node_id)], + properties=node_properties + ) + ) + seen_nodes.add(str(node_id)) + + # Get edges between the nodes in the result + if nodes_result: + node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result] + node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids]) + + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', within({node_ids_query})) + .outE() + .where(inV().has('graph', {self.graph_name}) + .has('entity_name', within({node_ids_query}))) + .path() + .by(elementMap()) + .by(elementMap()) + .by(elementMap()) + """ + edges_result = await self._query(query) + + for path in edges_result: + if len(path) >= 3: # source -> edge -> target + source = path[0] + edge_data = path[1] + target = path[2] + + source_id = source.get('entity_name', str(source.get('id', ''))) + target_id = target.get('entity_name', str(target.get('id', ''))) + + edge_id = f"{source_id}-{target_id}" + if edge_id in seen_edges: + continue + + # Create edge with properties + edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} + + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source_id), + target=str(target_id), + properties=edge_properties + ) + ) + seen_edges.add(edge_id) + + logger.info( + "Subgraph query successful | Node count: %d | Edge count: %d", + len(result.nodes), + len(result.edges) + ) + return result + + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node entity_names to be deleted + """ + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + entity_name_source = GremlinStorage._fix_name(source) + entity_name_target = GremlinStorage._fix_name(target) + + query = f"""g + .V().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_source}) + .outE() + .where(inV().has('graph', {self.graph_name}) + .has('entity_name', {entity_name_target})) + .drop() + """ + try: + await self._query(query) + logger.debug( + "{%s}: Deleted edge from '%s' to '%s'", + inspect.currentframe().f_code.co_name, + entity_name_source, + entity_name_target + ) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 33a5c12b..2ad4da18 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -3,7 +3,7 @@ import os from typing import Any, final from dataclasses import dataclass import numpy as np -from lightrag.utils import logger +from lightrag.utils import logger, compute_mdhash_id from ..base import BaseVectorStorage import 
pipmaster as pm @@ -124,7 +124,84 @@ class MilvusVectorDBStorage(BaseVectorStorage): pass async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity from the vector database + + Args: + entity_name: The name of the entity to delete + """ + try: + # Compute entity ID from name + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + + # Delete the entity from Milvus collection + result = self._client.delete( + collection_name=self.namespace, + pks=[entity_id] + ) + + if result and result.get("delete_count", 0) > 0: + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity + + Args: + entity_name: The name of the entity whose relations should be deleted + """ + try: + # Search for relations where entity is either source or target + expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' + + # Find all relations involving this entity + results = self._client.query( + collection_name=self.namespace, + filter=expr, + output_fields=["id"] + ) + + if not results or len(results) == 0: + logger.debug(f"No relations found for entity {entity_name}") + return + + # Extract IDs of relations to delete + relation_ids = [item["id"] for item in results] + logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") + + # Delete the relations + if relation_ids: + delete_result = self._client.delete( + collection_name=self.namespace, + pks=relation_ids + ) + + logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}") + + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") + + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + # Delete vectors by IDs + result = self._client.delete( + collection_name=self.namespace, + pks=ids + ) + + if result and result.get("delete_count", 0) > 0: + logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}") + else: + logger.debug(f"No vectors were deleted from {self.namespace}") + + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 0048b384..3afd2b44 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -15,7 +15,7 @@ from ..base import ( DocStatusStorage, ) from ..namespace import NameSpace, is_namespace -from ..utils import logger +from ..utils import logger, compute_mdhash_id from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge import pipmaster as pm @@ -333,7 +333,7 @@ class MongoGraphStorage(BaseGraphStorage): Check if there's a direct single-hop edge from source_node_id to target_node_id. We'll do a $graphLookup with maxDepth=0 from the source node—meaning - “Look up zero expansions.” Actually, for a direct edge check, we can do maxDepth=1 + "Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1 and then see if the target node is in the "reachableNodes" at depth=0. 
But typically for a direct edge, we might just do a find_one. @@ -795,6 +795,52 @@ class MongoGraphStorage(BaseGraphStorage): # Mongo handles persistence automatically pass + async def remove_nodes(self, nodes: list[str]) -> None: + """Delete multiple nodes + + Args: + nodes: List of node IDs to be deleted + """ + logger.info(f"Deleting {len(nodes)} nodes") + if not nodes: + return + + # 1. Remove all edges referencing these nodes (remove from edges array of other nodes) + await self.collection.update_many( + {}, + {"$pull": {"edges": {"target": {"$in": nodes}}}} + ) + + # 2. Delete the node documents + await self.collection.delete_many({"_id": {"$in": nodes}}) + + logger.debug(f"Successfully deleted nodes: {nodes}") + + async def remove_edges(self, edges: list[tuple[str, str]]) -> None: + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + logger.info(f"Deleting {len(edges)} edges") + if not edges: + return + + update_tasks = [] + for source, target in edges: + # Remove edge pointing to target from source node's edges array + update_tasks.append( + self.collection.update_one( + {"_id": source}, + {"$pull": {"edges": {"target": target}}} + ) + ) + + if update_tasks: + await asyncio.gather(*update_tasks) + + logger.debug(f"Successfully deleted edges: {edges}") + @final @dataclass @@ -932,11 +978,66 @@ class MongoVectorDBStorage(BaseVectorStorage): # Mongo handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") + if not ids: + return + + try: + result = await self._data.delete_many({"_id": {"$in": ids}}) + logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}") + except PyMongoError as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}") + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by its name + + Args: + entity_name: Name of the entity to delete + """ + try: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + + result = await self._data.delete_one({"_id": entity_id}) + if result.deleted_count > 0: + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + except PyMongoError as e: + logger.error(f"Error deleting entity {entity_name}: {str(e)}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity + + Args: + entity_name: Name of the entity whose relations should be deleted + """ + try: + # Find relations where entity appears as source or target + relations_cursor = self._data.find( + {"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]} + ) + relations = await relations_cursor.to_list(length=None) + + if not relations: + logger.debug(f"No relations found for entity {entity_name}") + return + + # Extract IDs of relations to delete + relation_ids = [relation["_id"] for relation in relations] + logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") + + # Delete the relations + result = await self._data.delete_many({"_id": {"$in": relation_ids}}) + logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}") + 
except PyMongoError as e: + logger.error(f"Error deleting relations for {entity_name}: {str(e)}") async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index af2ededb..d189679e 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -8,7 +8,7 @@ from typing import Any, Union, final import numpy as np import configparser -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..base import ( BaseGraphStorage, @@ -442,11 +442,55 @@ class OracleVectorDBStorage(BaseVectorStorage): # Oracles handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + if not ids: + return + + try: + SQL = SQL_TEMPLATES["delete_vectors"].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + await self.db.execute(SQL, params) + logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + raise + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete entity by name + + Args: + entity_name: Name of the entity to delete + """ + try: + SQL = SQL_TEMPLATES["delete_entity"] + params = {"workspace": self.db.workspace, "entity_name": entity_name} + await self.db.execute(SQL, params) + logger.info(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + raise async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations connected to an entity + + Args: + entity_name: Name of the entity whose relations should be deleted + """ + try: + SQL = SQL_TEMPLATES["delete_entity_relations"] + params = {"workspace": self.db.workspace, "entity_name": entity_name} + await self.db.execute(SQL, params) + logger.info(f"Successfully deleted relations for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for entity {entity_name}: {e}") + raise @final @@ -668,15 +712,206 @@ class OracleGraphStorage(BaseGraphStorage): return res async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """Delete a node from the graph + + Args: + node_id: ID of the node to delete + """ + try: + # First delete all relations connected to this node + delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] + params_relations = {"workspace": self.db.workspace, "entity_name": node_id} + await self.db.execute(delete_relations_sql, params_relations) + + # Then delete the node itself + delete_node_sql = SQL_TEMPLATES["delete_entity"] + params_node = {"workspace": self.db.workspace, "entity_name": node_id} + await self.db.execute(delete_node_sql, params_node) + + logger.info(f"Successfully deleted node {node_id} and all its relationships") + except Exception as e: + logger.error(f"Error deleting node {node_id}: {e}") + raise async def get_all_labels(self) -> list[str]: - raise NotImplementedError + """Get all unique entity types (labels) in the graph + + Returns: + List of unique entity types/labels + """ + try: + SQL = """ + SELECT DISTINCT entity_type + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY entity_type + """ + params = 
{"workspace": self.db.workspace} + results = await self.db.query(SQL, params, multirows=True) + + if results: + labels = [row["entity_type"] for row in results] + return labels + else: + return [] + except Exception as e: + logger.error(f"Error retrieving entity types: {e}") + return [] async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """Retrieve a connected subgraph starting from nodes matching the given label + + Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable. + Prioritizes nodes by: + 1. Nodes matching the specified label + 2. Nodes directly connected to matching nodes + 3. Node degree (number of connections) + + Args: + node_label: Label to match for starting nodes (use "*" for all nodes) + max_depth: Maximum depth of traversal from starting nodes + + Returns: + KnowledgeGraph object containing nodes and edges + """ + result = KnowledgeGraph() + + try: + # Define maximum number of nodes to return + max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000)) + + if node_label == "*": + # For "*" label, get all nodes up to the limit + nodes_sql = """ + SELECT name, entity_type, description, source_chunk_id + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY id + FETCH FIRST :limit ROWS ONLY + """ + nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes} + nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) + else: + # For specific label, find matching nodes and related nodes + nodes_sql = """ + WITH matching_nodes AS ( + SELECT name + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%') + ) + SELECT n.name, n.entity_type, n.description, n.source_chunk_id, + CASE + WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2 + WHEN EXISTS ( + SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e + WHERE workspace = :workspace + AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes)) + OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes))) + ) THEN 1 + ELSE 0 + END AS priority, + (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e + WHERE workspace = :workspace + AND (e.source_name = n.name OR e.target_name = n.name)) AS degree + FROM LIGHTRAG_GRAPH_NODES n + WHERE workspace = :workspace + ORDER BY priority DESC, degree DESC + FETCH FIRST :limit ROWS ONLY + """ + nodes_params = { + "workspace": self.db.workspace, + "node_label": node_label, + "limit": max_graph_nodes + } + nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) + + if not nodes: + logger.warning(f"No nodes found matching '{node_label}'") + return result + + # Create mapping of node IDs to be used to filter edges + node_names = [node["name"] for node in nodes] + + # Add nodes to result + seen_nodes = set() + for node in nodes: + node_id = node["name"] + if node_id in seen_nodes: + continue + + # Create node properties dictionary + properties = { + "entity_type": node["entity_type"], + "description": node["description"] or "", + "source_id": node["source_chunk_id"] or "" + } + + # Add node to result + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node["entity_type"]], + properties=properties + ) + ) + seen_nodes.add(node_id) + + # Get edges between these nodes + edges_sql = """ + SELECT source_name, target_name, weight, keywords, description, source_chunk_id + FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = 
:workspace + AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) + AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) + ORDER BY id + """ + edges_params = { + "workspace": self.db.workspace, + "node_names": node_names + } + edges = await self.db.query(edges_sql, edges_params, multirows=True) + + # Add edges to result + seen_edges = set() + for edge in edges: + source = edge["source_name"] + target = edge["target_name"] + edge_id = f"{source}-{target}" + + if edge_id in seen_edges: + continue + + # Create edge properties dictionary + properties = { + "weight": edge["weight"] or 0.0, + "keywords": edge["keywords"] or "", + "description": edge["description"] or "", + "source_id": edge["source_chunk_id"] or "" + } + + # Add edge to result + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="RELATED", + source=source, + target=target, + properties=properties + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except Exception as e: + logger.error(f"Error retrieving knowledge graph: {e}") + + return result N_T = { @@ -927,4 +1162,12 @@ SQL_TEMPLATES = { select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) )""", + # SQL for deletion + "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})", + "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name", + "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)", + "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph + MATCH (a) + WHERE a.workspace=:workspace AND a.name=:node_id + ACTION DELETE a)""", } diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 51044be5..7ce2b427 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -7,7 +7,7 @@ from typing import Any, Union, final import numpy as np import configparser -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge import sys from tenacity import ( @@ -512,11 +512,66 @@ class PGVectorStorage(BaseVectorStorage): # PG handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs from the storage. + + Args: + ids: List of vector IDs to be deleted + """ + if not ids: + return + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for vector deletion: {self.namespace}") + return + + ids_list = ",".join([f"'{id}'" for id in ids]) + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" + + try: + await self.db.execute(delete_sql, {"workspace": self.db.workspace}) + logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by its name from the vector storage. 
+ + Args: + entity_name: The name of the entity to delete + """ + try: + # Construct SQL to delete the entity + delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY + WHERE workspace=$1 AND entity_name=$2""" + + await self.db.execute( + delete_sql, + {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity. + + Args: + entity_name: The name of the entity whose relations should be deleted + """ + try: + # Delete relations where the entity is either the source or target + delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION + WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" + + await self.db.execute( + delete_sql, + {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted relations for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for entity {entity_name}: {e}") @final @@ -1086,20 +1141,192 @@ class PGGraphStorage(BaseGraphStorage): print("Implemented but never called.") async def delete_node(self, node_id: str) -> None: - raise NotImplementedError + """ + Delete a node from the graph. + + Args: + node_id (str): The ID of the node to delete. + """ + label = self._encode_graph_label(node_id.strip('"')) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + DETACH DELETE n + $$) AS (n agtype)""" % (self.graph_name, label) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during node deletion: {%s}", e) + raise + + async def remove_nodes(self, node_ids: list[str]) -> None: + """ + Remove multiple nodes from the graph. + + Args: + node_ids (list[str]): A list of node IDs to remove. + """ + encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids] + node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + WHERE n.node_id IN [%s] + DETACH DELETE n + $$) AS (n agtype)""" % (self.graph_name, node_id_list) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during node removal: {%s}", e) + raise + + async def remove_edges(self, edges: list[tuple[str, str]]) -> None: + """ + Remove multiple edges from the graph. + + Args: + edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). + """ + encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges] + edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges]) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity)-[r]->(b:Entity) + WHERE [a.node_id, b.node_id] IN [%s] + DELETE r + $$) AS (r agtype)""" % (self.graph_name, edge_list) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during edge removal: {%s}", e) + raise + + async def get_all_labels(self) -> list[str]: + """ + Get all labels (node IDs) in the graph. + + Returns: + list[str]: A list of all labels in the graph. 
+ """ + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + RETURN DISTINCT n.node_id AS label + $$) AS (label text)""" % self.graph_name + + results = await self._query(query) + labels = [self._decode_graph_label(result["label"]) for result in results] + + return labels async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError + """ + Generate node embeddings using the specified algorithm. - async def get_all_labels(self) -> list[str]: - raise NotImplementedError + Args: + algorithm (str): The name of the embedding algorithm to use. + + Returns: + tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs. + """ + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Unsupported embedding algorithm: {algorithm}") + + embed_func = self._node_embed_algorithms[algorithm] + return await embed_func() async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. + + Args: + node_label (str): The label of the node to start from. If "*", the entire graph is returned. + max_depth (int): The maximum depth to traverse from the starting node. + + Returns: + KnowledgeGraph: The retrieved subgraph. + """ + MAX_GRAPH_NODES = 1000 + + if node_label == "*": + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + OPTIONAL MATCH (n)-[r]->(m:Entity) + RETURN n, r, m + LIMIT %d + $$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES) + else: + encoded_node_label = self._encode_graph_label(node_label.strip('"')) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + OPTIONAL MATCH p = (n)-[*..%d]-(m) + RETURN nodes(p) AS nodes, relationships(p) AS relationships + LIMIT %d + $$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES) + + results = await self._query(query) + + nodes = set() + edges = [] + + for result in results: + if node_label == "*": + if result["n"]: + node = result["n"] + nodes.add(self._decode_graph_label(node["node_id"])) + if result["m"]: + node = result["m"] + nodes.add(self._decode_graph_label(node["node_id"])) + if result["r"]: + edge = result["r"] + src_id = self._decode_graph_label(edge["start_id"]) + tgt_id = self._decode_graph_label(edge["end_id"]) + edges.append((src_id, tgt_id)) + else: + if result["nodes"]: + for node in result["nodes"]: + nodes.add(self._decode_graph_label(node["node_id"])) + if result["relationships"]: + for edge in result["relationships"]: + src_id = self._decode_graph_label(edge["start_id"]) + tgt_id = self._decode_graph_label(edge["end_id"]) + edges.append((src_id, tgt_id)) + + kg = KnowledgeGraph( + nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes], + edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges], + ) + + return kg + + async def get_all_labels(self) -> list[str]: + """ + Get all node labels in the graph + Returns: + [label1, label2, ...] 
# Alphabetically sorted label list + """ + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity) + RETURN DISTINCT n.node_id AS label + ORDER BY label + $$) AS (label agtype)""" % (self.graph_name) + + try: + results = await self._query(query) + labels = [] + for record in results: + if record["label"]: + labels.append(self._decode_graph_label(record["label"])) + return labels + except Exception as e: + logger.error(f"Error getting all labels: {str(e)}") + return [] async def drop(self) -> None: """Drop the storage""" diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index b08f0b62..e3488caa 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import Any, final +from typing import Any, final, List from dataclasses import dataclass import numpy as np import hashlib @@ -141,8 +141,91 @@ class QdrantVectorDBStorage(BaseVectorStorage): # Qdrant handles persistence automatically pass + async def delete(self, ids: List[str]) -> None: + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + # Convert regular ids to Qdrant compatible ids + qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids] + # Delete points from the collection + self._client.delete( + collection_name=self.namespace, + points_selector=models.PointIdsList( + points=qdrant_ids, + ), + wait=True + ) + logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by name + + Args: + entity_name: Name of the entity to delete + """ + try: + # Generate the entity ID + entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + + # Delete the entity point from the collection + self._client.delete( + collection_name=self.namespace, + points_selector=models.PointIdsList( + points=[entity_id], + ), + wait=True + ) + logger.debug(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity + + Args: + entity_name: Name of the entity whose relations should be deleted + """ + try: + # Find relations where the entity is either source or target + results = self._client.scroll( + collection_name=self.namespace, + scroll_filter=models.Filter( + should=[ + models.FieldCondition( + key="src_id", + match=models.MatchValue(value=entity_name) + ), + models.FieldCondition( + key="tgt_id", + match=models.MatchValue(value=entity_name) + ) + ] + ), + with_payload=True, + limit=1000 # Adjust as needed for your use case + ) + + # Extract points that need to be deleted + relation_points = results[0] + ids_to_delete = [point.id for point in relation_points] + + if ids_to_delete: + # Delete the relations + self._client.delete( + collection_name=self.namespace, + points_selector=models.PointIdsList( + points=ids_to_delete, + ), + wait=True + ) + logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}") + else: + logger.debug(f"No relations found for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") diff 
--git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 7e177346..bb42b367 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -9,7 +9,7 @@ if not pm.is_installed("redis"): # aioredis is a depricated library, replaced with redis from redis.asyncio import Redis -from lightrag.utils import logger +from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseKVStorage import json @@ -64,3 +64,79 @@ class RedisKVStorage(BaseKVStorage): async def index_done_callback(self) -> None: # Redis handles persistence automatically pass + + async def delete(self, ids: list[str]) -> None: + """Delete entries with specified IDs + + Args: + ids: List of entry IDs to be deleted + """ + if not ids: + return + + pipe = self._redis.pipeline() + for id in ids: + pipe.delete(f"{self.namespace}:{id}") + + results = await pipe.execute() + deleted_count = sum(results) + logger.info(f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}") + + async def delete_entity(self, entity_name: str) -> None: + """Delete an entity by name + + Args: + entity_name: Name of the entity to delete + """ + + try: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + + # Delete the entity + result = await self._redis.delete(f"{self.namespace}:{entity_id}") + + if result: + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete all relations associated with an entity + + Args: + entity_name: Name of the entity whose relations should be deleted + """ + try: + # Get all keys in this namespace + cursor = 0 + relation_keys = [] + pattern = f"{self.namespace}:*" + + while True: + cursor, keys = await self._redis.scan(cursor, match=pattern) + + # For each key, get the value and check if it's related to entity_name + for key in keys: + value = await self._redis.get(key) + if value: + data = json.loads(value) + # Check if this is a relation involving the entity + if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name: + relation_keys.append(key) + + # Exit loop when cursor returns to 0 + if cursor == 0: + break + + # Delete the relation keys + if relation_keys: + deleted = await self._redis.delete(*relation_keys) + logger.debug(f"Deleted {deleted} relations for {entity_name}") + else: + logger.debug(f"No relations found for entity {entity_name}") + + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 51d1c365..f791d401 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -5,7 +5,7 @@ from typing import Any, Union, final import numpy as np -from lightrag.types import KnowledgeGraph +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage @@ -566,15 +566,148 @@ class TiDBGraphStorage(BaseGraphStorage): pass async def delete_node(self, node_id: str) -> None: - raise NotImplementedError - + """Delete a node and all its related edges + + Args: + node_id: The ID of the node to delete + """ + # First delete all edges related to this node + await self.db.execute(SQL_TEMPLATES["delete_node_edges"], + {"name": node_id, 
"workspace": self.db.workspace}) + + # Then delete the node itself + await self.db.execute(SQL_TEMPLATES["delete_node"], + {"name": node_id, "workspace": self.db.workspace}) + + logger.debug(f"Node {node_id} and its related edges have been deleted from the graph") + async def get_all_labels(self) -> list[str]: - raise NotImplementedError - + """Get all entity types (labels) in the database + + Returns: + List of labels sorted alphabetically + """ + result = await self.db.query( + SQL_TEMPLATES["get_all_labels"], + {"workspace": self.db.workspace}, + multirows=True + ) + + if not result: + return [] + + # Extract all labels + return [item["label"] for item in result] + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """ + Get a connected subgraph of nodes matching the specified label + Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000) + + Args: + node_label: The node label to match + max_depth: Maximum depth of the subgraph + + Returns: + KnowledgeGraph object containing nodes and edges + """ + result = KnowledgeGraph() + MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + + # Get matching nodes + if node_label == "*": + # Handle special case, get all nodes + node_results = await self.db.query( + SQL_TEMPLATES["get_all_nodes"], + {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, + multirows=True + ) + else: + # Get nodes matching the label + label_pattern = f"%{node_label}%" + node_results = await self.db.query( + SQL_TEMPLATES["get_matching_nodes"], + {"workspace": self.db.workspace, "label_pattern": label_pattern}, + multirows=True + ) + + if not node_results: + logger.warning(f"No nodes found matching label {node_label}") + return result + + # Limit the number of returned nodes + if len(node_results) > MAX_GRAPH_NODES: + node_results = node_results[:MAX_GRAPH_NODES] + + # Extract node names for edge query + node_names = [node["name"] for node in node_results] + node_names_str = ",".join([f"'{name}'" for name in node_names]) + + # Add nodes to result + for node in node_results: + node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]} + result.nodes.append( + KnowledgeGraphNode( + id=node["name"], + labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]], + properties=node_properties + ) + ) + + # Get related edges + edge_results = await self.db.query( + SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), + {"workspace": self.db.workspace}, + multirows=True + ) + + if edge_results: + # Add edges to result + for edge in edge_results: + # Only include edges related to selected nodes + if edge["source_name"] in node_names and edge["target_name"] in node_names: + edge_id = f"{edge['source_name']}-{edge['target_name']}" + edge_properties = {k: v for k, v in edge.items() + if k not in ["id", "source_name", "target_name"]} + + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="RELATED", + source=edge["source_name"], + target=edge["target_name"], + properties=edge_properties + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result + + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node IDs to delete + """ + for node_id in nodes: + await self.delete_node(node_id) + + async def remove_edges(self, edges: list[tuple[str, str]]): + 
"""Delete multiple edges + + Args: + edges: List of edges to delete, each edge is a (source, target) tuple + """ + for source, target in edges: + await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], { + "source": source, + "target": target, + "workspace": self.db.workspace + }) N_T = { @@ -785,4 +918,39 @@ SQL_TEMPLATES = { weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description), source_chunk_id = VALUES(source_chunk_id) """, + "delete_node": """ + DELETE FROM LIGHTRAG_GRAPH_NODES + WHERE name = :name AND workspace = :workspace + """, + "delete_node_edges": """ + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace + """, + "get_all_labels": """ + SELECT DISTINCT entity_type as label + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY entity_type + """, + "get_matching_nodes": """ + SELECT * FROM LIGHTRAG_GRAPH_NODES + WHERE name LIKE :label_pattern AND workspace = :workspace + ORDER BY name + """, + "get_all_nodes": """ + SELECT * FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace + ORDER BY name + LIMIT :max_nodes + """, + "get_related_edges": """ + SELECT * FROM LIGHTRAG_GRAPH_EDGES + WHERE (source_name IN (:node_names) OR target_name IN (:node_names)) + AND workspace = :workspace + """, + "remove_multiple_edges": """ + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE (source_name = :source AND target_name = :target) + AND workspace = :workspace + """ } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6f42003d..eeed8a70 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1399,6 +1399,54 @@ class LightRAG: ] ) + def delete_by_relation(self, source_entity: str, target_entity: str) -> None: + """Synchronously delete a relation between two entities. + + Args: + source_entity: Name of the source entity + target_entity: Name of the target entity + """ + loop = always_get_an_event_loop() + return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity)) + + async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None: + """Asynchronously delete a relation between two entities. 
+ + Args: + source_entity: Name of the source entity + target_entity: Name of the target entity + """ + try: + # Check if the relation exists + edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity) + if not edge_exists: + logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist") + return + + # Delete relation from vector database + relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-") + await self.relationships_vdb.delete([relation_id]) + + # Delete relation from knowledge graph + await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)]) + + logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'") + await self._delete_relation_done() + except Exception as e: + logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}") + + async def _delete_relation_done(self) -> None: + """Callback after relation deletion is complete""" + await asyncio.gather( + *[ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.relationships_vdb, + self.chunk_entity_relation_graph, + ] + ] + ) + def _get_content_summary(self, content: str, max_length: int = 100) -> str: """Get summary of document content From 81568f3badbba85294ae0fc2d759a6f7f1715706 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 15:53:20 +0800 Subject: [PATCH 26/32] fix linting --- lightrag/kg/age_impl.py | 54 ++++++++-------- lightrag/kg/chroma_impl.py | 14 ++-- lightrag/kg/gremlin_impl.py | 120 +++++++++++++++++++--------------- lightrag/kg/milvus_impl.py | 59 ++++++++--------- lightrag/kg/mongo_impl.py | 48 ++++++++------ lightrag/kg/oracle_impl.py | 108 ++++++++++++++++--------------- lightrag/kg/postgres_impl.py | 78 +++++++++++----------- lightrag/kg/qdrant_impl.py | 40 ++++++------ lightrag/kg/redis_impl.py | 39 ++++++----- lightrag/kg/tidb_impl.py | 121 ++++++++++++++++++++--------------- lightrag/lightrag.py | 40 ++++++++---- 11 files changed, 394 insertions(+), 327 deletions(-) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index c6b98221..22951554 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -619,7 +619,7 @@ class AGEStorage(BaseGraphStorage): node_id: The label of the node to delete """ entity_name_label = node_id.strip('"') - + query = """ MATCH (n:`{label}`) DETACH DELETE n @@ -650,18 +650,20 @@ class AGEStorage(BaseGraphStorage): for source, target in edges: entity_name_label_source = source.strip('"') entity_name_label_target = target.strip('"') - + query = """ MATCH (source:`{src_label}`)-[r]->(target:`{tgt_label}`) DELETE r """ params = { "src_label": AGEStorage._encode_graph_label(entity_name_label_source), - "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target) + "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), } try: await self._query(query, **params) - logger.debug(f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'") + logger.debug( + f"Deleted edge from '{entity_name_label_source}' to '{entity_name_label_target}'" + ) except Exception as e: logger.error(f"Error during edge deletion: {str(e)}") raise @@ -683,7 +685,7 @@ class AGEStorage(BaseGraphStorage): async def get_all_labels(self) -> list[str]: """Get all node labels in the database - + Returns: ["label1", "label2", ...] 
# Alphabetically sorted label list """ @@ -692,7 +694,7 @@ class AGEStorage(BaseGraphStorage): RETURN DISTINCT labels(n) AS node_labels """ results = await self._query(query) - + all_labels = [] for record in results: if record and "node_labels" in record: @@ -701,7 +703,7 @@ class AGEStorage(BaseGraphStorage): # Decode label decoded_label = AGEStorage._decode_graph_label(label) all_labels.append(decoded_label) - + # Remove duplicates and sort return sorted(list(set(all_labels))) @@ -719,7 +721,7 @@ class AGEStorage(BaseGraphStorage): Args: node_label: String to match in node labels (will match any node containing this string in its label) max_depth: Maximum depth of the graph. Defaults to 5. - + Returns: KnowledgeGraph: Complete connected subgraph for specified node """ @@ -727,7 +729,7 @@ class AGEStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - + # Handle special case for "*" label if node_label == "*": # Query all nodes and sort by degree @@ -741,7 +743,7 @@ class AGEStorage(BaseGraphStorage): """ params = {"max_nodes": max_graph_nodes} nodes_result = await self._query(query, **params) - + # Add nodes to result node_ids = [] for record in nodes_result: @@ -755,12 +757,12 @@ class AGEStorage(BaseGraphStorage): KnowledgeGraphNode( id=node_id, labels=[node_label], - properties=node_properties + properties=node_properties, ) ) seen_nodes.add(node_id) node_ids.append(node_id) - + # Query edges between these nodes if node_ids: edges_query = """ @@ -770,7 +772,7 @@ class AGEStorage(BaseGraphStorage): """ edges_params = {"node_ids": node_ids} edges_result = await self._query(edges_query, **edges_params) - + # Add edges to result for record in edges_result: if "r" in record and "a" in record and "b" in record: @@ -785,7 +787,7 @@ class AGEStorage(BaseGraphStorage): type="DIRECTED", source=source, target=target, - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) @@ -793,7 +795,7 @@ class AGEStorage(BaseGraphStorage): # For specific label, use partial matching entity_name_label = node_label.strip('"') encoded_label = AGEStorage._encode_graph_label(entity_name_label) - + # Find matching start nodes start_query = """ MATCH (n:`{label}`) @@ -801,17 +803,14 @@ class AGEStorage(BaseGraphStorage): """ start_params = {"label": encoded_label} start_nodes = await self._query(start_query, **start_params) - + if not start_nodes: logger.warning(f"No nodes found with label '{entity_name_label}'!") return result - + # Traverse graph from each start node for start_node_record in start_nodes: if "n" in start_node_record: - start_node = start_node_record["n"] - start_id = str(start_node.get("id", "")) - # Use BFS to traverse graph query = """ MATCH (start:`{label}`) @@ -823,25 +822,28 @@ class AGEStorage(BaseGraphStorage): """ params = {"label": encoded_label, "max_depth": max_depth} results = await self._query(query, **params) - + # Extract nodes and edges from results for record in results: if "path_nodes" in record: # Process nodes for node in record["path_nodes"]: node_id = str(node.get("id", "")) - if node_id not in seen_nodes and len(seen_nodes) < max_graph_nodes: + if ( + node_id not in seen_nodes + and len(seen_nodes) < max_graph_nodes + ): node_properties = {k: v for k, v in node.items()} node_label = node.get("label", "") result.nodes.append( KnowledgeGraphNode( id=node_id, labels=[node_label], - properties=node_properties + properties=node_properties, ) ) seen_nodes.add(node_id) - + if "path_rels" in record: # Process edges 
for rel in record["path_rels"]: @@ -856,11 +858,11 @@ class AGEStorage(BaseGraphStorage): type=rel.get("label", "DIRECTED"), source=source, target=target, - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) - + logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index d36e6d7c..ea4b31a1 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -194,7 +194,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): async def delete_entity(self, entity_name: str) -> None: """Delete an entity by its ID. - + Args: entity_name: The ID of the entity to delete """ @@ -206,24 +206,26 @@ class ChromaVectorDBStorage(BaseVectorStorage): raise async def delete_entity_relation(self, entity_name: str) -> None: - """Delete an entity and its relations by ID. + """Delete an entity and its relations by ID. In vector DB context, this is equivalent to delete_entity. - + Args: entity_name: The ID of the entity to delete """ await self.delete_entity(entity_name) - + async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ try: logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") self._collection.delete(ids=ids) - logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") raise diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 4d343bb5..ddb7559f 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -397,12 +397,12 @@ class GremlinStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """Delete a node with the specified entity_name - + Args: node_id: The entity_name of the node to delete """ entity_name = GremlinStorage._fix_name(node_id) - + query = f"""g .V().has('graph', {self.graph_name}) .has('entity_name', {entity_name}) @@ -413,7 +413,7 @@ class GremlinStorage(BaseGraphStorage): logger.debug( "{%s}: Deleted node with entity_name '%s'", inspect.currentframe().f_code.co_name, - entity_name + entity_name, ) except Exception as e: logger.error(f"Error during node deletion: {str(e)}") @@ -425,13 +425,13 @@ class GremlinStorage(BaseGraphStorage): """ Embed nodes using the specified algorithm. Currently, only node2vec is supported but never called. - + Args: algorithm: The name of the embedding algorithm to use - + Returns: A tuple of (embeddings, node_ids) - + Raises: NotImplementedError: If the specified algorithm is not supported ValueError: If the algorithm is not supported @@ -458,7 +458,7 @@ class GremlinStorage(BaseGraphStorage): logger.debug( "{%s}: Retrieved %d labels", inspect.currentframe().f_code.co_name, - len(labels) + len(labels), ) return labels except Exception as e: @@ -471,7 +471,7 @@ class GremlinStorage(BaseGraphStorage): """ Retrieve a connected subgraph of nodes where the entity_name includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). 
- + Args: node_label: Entity name of the starting node max_depth: Maximum depth of the subgraph @@ -482,12 +482,12 @@ class GremlinStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - + # Get maximum number of graph nodes from environment variable, default is 1000 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - + entity_name = GremlinStorage._fix_name(node_label) - + # Handle special case for "*" label if node_label == "*": # For "*", get all nodes and their edges (limited by MAX_GRAPH_NODES) @@ -497,25 +497,27 @@ class GremlinStorage(BaseGraphStorage): .elementMap() """ nodes_result = await self._query(query) - + # Add nodes to result for node_data in nodes_result: - node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + node_id = node_data.get("entity_name", str(node_data.get("id", ""))) if str(node_id) in seen_nodes: continue - + # Create node with properties - node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} - + node_properties = { + k: v for k, v in node_data.items() if k not in ["id", "label"] + } + result.nodes.append( KnowledgeGraphNode( id=str(node_id), - labels=[str(node_id)], - properties=node_properties + labels=[str(node_id)], + properties=node_properties, ) ) seen_nodes.add(str(node_id)) - + # Get and add edges if nodes_result: query = f"""g @@ -530,30 +532,34 @@ class GremlinStorage(BaseGraphStorage): .by(elementMap()) """ edges_result = await self._query(query) - + for path in edges_result: if len(path) >= 3: # source -> edge -> target source = path[0] edge_data = path[1] target = path[2] - - source_id = source.get('entity_name', str(source.get('id', ''))) - target_id = target.get('entity_name', str(target.get('id', ''))) - + + source_id = source.get("entity_name", str(source.get("id", ""))) + target_id = target.get("entity_name", str(target.get("id", ""))) + edge_id = f"{source_id}-{target_id}" if edge_id in seen_edges: continue - + # Create edge with properties - edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} - + edge_properties = { + k: v + for k, v in edge_data.items() + if k not in ["id", "label"] + } + result.edges.append( KnowledgeGraphEdge( id=edge_id, type="DIRECTED", source=str(source_id), target=str(target_id), - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) @@ -570,30 +576,36 @@ class GremlinStorage(BaseGraphStorage): .elementMap() """ nodes_result = await self._query(query) - + # Add nodes to result for node_data in nodes_result: - node_id = node_data.get('entity_name', str(node_data.get('id', ''))) + node_id = node_data.get("entity_name", str(node_data.get("id", ""))) if str(node_id) in seen_nodes: continue - + # Create node with properties - node_properties = {k: v for k, v in node_data.items() if k not in ['id', 'label']} - + node_properties = { + k: v for k, v in node_data.items() if k not in ["id", "label"] + } + result.nodes.append( KnowledgeGraphNode( id=str(node_id), - labels=[str(node_id)], - properties=node_properties + labels=[str(node_id)], + properties=node_properties, ) ) seen_nodes.add(str(node_id)) - + # Get edges between the nodes in the result if nodes_result: - node_ids = [n.get('entity_name', str(n.get('id', ''))) for n in nodes_result] - node_ids_query = ", ".join([GremlinStorage._to_value_map(nid) for nid in node_ids]) - + node_ids = [ + n.get("entity_name", str(n.get("id", ""))) for n in nodes_result + ] + node_ids_query = ", ".join( + 
[GremlinStorage._to_value_map(nid) for nid in node_ids] + ) + query = f"""g .V().has('graph', {self.graph_name}) .has('entity_name', within({node_ids_query})) @@ -606,38 +618,42 @@ class GremlinStorage(BaseGraphStorage): .by(elementMap()) """ edges_result = await self._query(query) - + for path in edges_result: if len(path) >= 3: # source -> edge -> target source = path[0] edge_data = path[1] target = path[2] - - source_id = source.get('entity_name', str(source.get('id', ''))) - target_id = target.get('entity_name', str(target.get('id', ''))) - + + source_id = source.get("entity_name", str(source.get("id", ""))) + target_id = target.get("entity_name", str(target.get("id", ""))) + edge_id = f"{source_id}-{target_id}" if edge_id in seen_edges: continue - + # Create edge with properties - edge_properties = {k: v for k, v in edge_data.items() if k not in ['id', 'label']} - + edge_properties = { + k: v + for k, v in edge_data.items() + if k not in ["id", "label"] + } + result.edges.append( KnowledgeGraphEdge( id=edge_id, type="DIRECTED", source=str(source_id), target=str(target_id), - properties=edge_properties + properties=edge_properties, ) ) seen_edges.add(edge_id) - + logger.info( "Subgraph query successful | Node count: %d | Edge count: %d", len(result.nodes), - len(result.edges) + len(result.edges), ) return result @@ -659,7 +675,7 @@ class GremlinStorage(BaseGraphStorage): for source, target in edges: entity_name_source = GremlinStorage._fix_name(source) entity_name_target = GremlinStorage._fix_name(target) - + query = f"""g .V().has('graph', {self.graph_name}) .has('entity_name', {entity_name_source}) @@ -674,7 +690,7 @@ class GremlinStorage(BaseGraphStorage): "{%s}: Deleted edge from '%s' to '%s'", inspect.currentframe().f_code.co_name, entity_name_source, - entity_name_target + entity_name_target, ) except Exception as e: logger.error(f"Error during edge deletion: {str(e)}") diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 2ad4da18..7242f03d 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -125,83 +125,84 @@ class MilvusVectorDBStorage(BaseVectorStorage): async def delete_entity(self, entity_name: str) -> None: """Delete an entity from the vector database - + Args: entity_name: The name of the entity to delete """ try: # Compute entity ID from name entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Delete the entity from Milvus collection result = self._client.delete( - collection_name=self.namespace, - pks=[entity_id] + collection_name=self.namespace, pks=[entity_id] ) - + if result and result.get("delete_count", 0) > 0: logger.debug(f"Successfully deleted entity {entity_name}") else: logger.debug(f"Entity {entity_name} not found in storage") - + except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: The name of the entity whose relations should be deleted """ try: # Search for relations where entity is either source or target expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' - + # Find all relations involving this entity results = self._client.query( - collection_name=self.namespace, - filter=expr, - output_fields=["id"] + collection_name=self.namespace, filter=expr, 
output_fields=["id"] ) - + if not results or len(results) == 0: logger.debug(f"No relations found for entity {entity_name}") return - + # Extract IDs of relations to delete relation_ids = [item["id"] for item in results] - logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") - + logger.debug( + f"Found {len(relation_ids)} relations for entity {entity_name}" + ) + # Delete the relations if relation_ids: delete_result = self._client.delete( - collection_name=self.namespace, - pks=relation_ids + collection_name=self.namespace, pks=relation_ids ) - - logger.debug(f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}") - + + logger.debug( + f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}" + ) + except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ try: # Delete vectors by IDs - result = self._client.delete( - collection_name=self.namespace, - pks=ids - ) - + result = self._client.delete(collection_name=self.namespace, pks=ids) + if result and result.get("delete_count", 0) > 0: - logger.debug(f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}" + ) else: logger.debug(f"No vectors were deleted from {self.namespace}") - + except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 3afd2b44..c2957502 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -804,16 +804,15 @@ class MongoGraphStorage(BaseGraphStorage): logger.info(f"Deleting {len(nodes)} nodes") if not nodes: return - + # 1. Remove all edges referencing these nodes (remove from edges array of other nodes) await self.collection.update_many( - {}, - {"$pull": {"edges": {"target": {"$in": nodes}}}} + {}, {"$pull": {"edges": {"target": {"$in": nodes}}}} ) - + # 2. 
Delete the node documents await self.collection.delete_many({"_id": {"$in": nodes}}) - + logger.debug(f"Successfully deleted nodes: {nodes}") async def remove_edges(self, edges: list[tuple[str, str]]) -> None: @@ -825,20 +824,19 @@ class MongoGraphStorage(BaseGraphStorage): logger.info(f"Deleting {len(edges)} edges") if not edges: return - + update_tasks = [] for source, target in edges: # Remove edge pointing to target from source node's edges array update_tasks.append( self.collection.update_one( - {"_id": source}, - {"$pull": {"edges": {"target": target}}} + {"_id": source}, {"$pull": {"edges": {"target": target}}} ) ) - + if update_tasks: await asyncio.gather(*update_tasks) - + logger.debug(f"Successfully deleted edges: {edges}") @@ -987,23 +985,29 @@ class MongoVectorDBStorage(BaseVectorStorage): logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") if not ids: return - + try: result = await self._data.delete_many({"_id": {"$in": ids}}) - logger.debug(f"Successfully deleted {result.deleted_count} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {result.deleted_count} vectors from {self.namespace}" + ) except PyMongoError as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {str(e)}") + logger.error( + f"Error while deleting vectors from {self.namespace}: {str(e)}" + ) async def delete_entity(self, entity_name: str) -> None: """Delete an entity by its name - + Args: entity_name: Name of the entity to delete """ try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + result = await self._data.delete_one({"_id": entity_id}) if result.deleted_count > 0: logger.debug(f"Successfully deleted entity {entity_name}") @@ -1014,7 +1018,7 @@ class MongoVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -1024,15 +1028,17 @@ class MongoVectorDBStorage(BaseVectorStorage): {"$or": [{"src_id": entity_name}, {"tgt_id": entity_name}]} ) relations = await relations_cursor.to_list(length=None) - + if not relations: logger.debug(f"No relations found for entity {entity_name}") return - + # Extract IDs of relations to delete relation_ids = [relation["_id"] for relation in relations] - logger.debug(f"Found {len(relation_ids)} relations for entity {entity_name}") - + logger.debug( + f"Found {len(relation_ids)} relations for entity {entity_name}" + ) + # Delete the relations result = await self._data.delete_many({"_id": {"$in": relation_ids}}) logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}") diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index d189679e..5dee1143 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -444,27 +444,29 @@ class OracleVectorDBStorage(BaseVectorStorage): async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ if not ids: return - + try: SQL = SQL_TEMPLATES["delete_vectors"].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} await self.db.execute(SQL, params) - logger.info(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + logger.info( + f"Successfully deleted 
{len(ids)} vectors from {self.namespace}" + ) except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") raise async def delete_entity(self, entity_name: str) -> None: """Delete entity by name - + Args: entity_name: Name of the entity to delete """ @@ -479,7 +481,7 @@ class OracleVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations connected to an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -713,7 +715,7 @@ class OracleGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """Delete a node from the graph - + Args: node_id: ID of the node to delete """ @@ -722,33 +724,35 @@ class OracleGraphStorage(BaseGraphStorage): delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] params_relations = {"workspace": self.db.workspace, "entity_name": node_id} await self.db.execute(delete_relations_sql, params_relations) - + # Then delete the node itself delete_node_sql = SQL_TEMPLATES["delete_entity"] params_node = {"workspace": self.db.workspace, "entity_name": node_id} await self.db.execute(delete_node_sql, params_node) - - logger.info(f"Successfully deleted node {node_id} and all its relationships") + + logger.info( + f"Successfully deleted node {node_id} and all its relationships" + ) except Exception as e: logger.error(f"Error deleting node {node_id}: {e}") raise async def get_all_labels(self) -> list[str]: """Get all unique entity types (labels) in the graph - + Returns: List of unique entity types/labels """ try: SQL = """ - SELECT DISTINCT entity_type - FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace + SELECT DISTINCT entity_type + FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace ORDER BY entity_type """ params = {"workspace": self.db.workspace} results = await self.db.query(SQL, params, multirows=True) - + if results: labels = [row["entity_type"] for row in results] return labels @@ -762,26 +766,26 @@ class OracleGraphStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """Retrieve a connected subgraph starting from nodes matching the given label - + Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable. Prioritizes nodes by: 1. Nodes matching the specified label 2. Nodes directly connected to matching nodes 3. 
Node degree (number of connections) - + Args: node_label: Label to match for starting nodes (use "*" for all nodes) max_depth: Maximum depth of traversal from starting nodes - + Returns: KnowledgeGraph object containing nodes and edges """ result = KnowledgeGraph() - + try: # Define maximum number of nodes to return max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000)) - + if node_label == "*": # For "*" label, get all nodes up to the limit nodes_sql = """ @@ -791,30 +795,33 @@ class OracleGraphStorage(BaseGraphStorage): ORDER BY id FETCH FIRST :limit ROWS ONLY """ - nodes_params = {"workspace": self.db.workspace, "limit": max_graph_nodes} + nodes_params = { + "workspace": self.db.workspace, + "limit": max_graph_nodes, + } nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) else: # For specific label, find matching nodes and related nodes nodes_sql = """ WITH matching_nodes AS ( - SELECT name + SELECT name FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace + WHERE workspace = :workspace AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%') ) SELECT n.name, n.entity_type, n.description, n.source_chunk_id, CASE WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2 WHEN EXISTS ( - SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e + SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e WHERE workspace = :workspace AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes)) OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes))) ) THEN 1 ELSE 0 END AS priority, - (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e - WHERE workspace = :workspace + (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e + WHERE workspace = :workspace AND (e.source_name = n.name OR e.target_name = n.name)) AS degree FROM LIGHTRAG_GRAPH_NODES n WHERE workspace = :workspace @@ -822,43 +829,41 @@ class OracleGraphStorage(BaseGraphStorage): FETCH FIRST :limit ROWS ONLY """ nodes_params = { - "workspace": self.db.workspace, + "workspace": self.db.workspace, "node_label": node_label, - "limit": max_graph_nodes + "limit": max_graph_nodes, } nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) - + if not nodes: logger.warning(f"No nodes found matching '{node_label}'") return result - + # Create mapping of node IDs to be used to filter edges node_names = [node["name"] for node in nodes] - + # Add nodes to result seen_nodes = set() for node in nodes: node_id = node["name"] if node_id in seen_nodes: continue - + # Create node properties dictionary properties = { "entity_type": node["entity_type"], "description": node["description"] or "", - "source_id": node["source_chunk_id"] or "" + "source_id": node["source_chunk_id"] or "", } - + # Add node to result result.nodes.append( KnowledgeGraphNode( - id=node_id, - labels=[node["entity_type"]], - properties=properties + id=node_id, labels=[node["entity_type"]], properties=properties ) ) seen_nodes.add(node_id) - + # Get edges between these nodes edges_sql = """ SELECT source_name, target_name, weight, keywords, description, source_chunk_id @@ -868,30 +873,27 @@ class OracleGraphStorage(BaseGraphStorage): AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) ORDER BY id """ - edges_params = { - "workspace": self.db.workspace, - "node_names": node_names - } + edges_params = {"workspace": self.db.workspace, "node_names": node_names} edges = await self.db.query(edges_sql, edges_params, multirows=True) - + # Add edges to result seen_edges = set() for edge 
in edges: source = edge["source_name"] target = edge["target_name"] edge_id = f"{source}-{target}" - + if edge_id in seen_edges: continue - + # Create edge properties dictionary properties = { "weight": edge["weight"] or 0.0, "keywords": edge["keywords"] or "", "description": edge["description"] or "", - "source_id": edge["source_chunk_id"] or "" + "source_id": edge["source_chunk_id"] or "", } - + # Add edge to result result.edges.append( KnowledgeGraphEdge( @@ -899,18 +901,18 @@ class OracleGraphStorage(BaseGraphStorage): type="RELATED", source=source, target=target, - properties=properties + properties=properties, ) ) seen_edges.add(edge_id) - + logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) - + except Exception as e: logger.error(f"Error retrieving knowledge graph: {e}") - + return result @@ -1166,8 +1168,8 @@ SQL_TEMPLATES = { "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})", "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name", "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)", - "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph - MATCH (a) - WHERE a.workspace=:workspace AND a.name=:node_id + "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph + MATCH (a) + WHERE a.workspace=:workspace AND a.name=:node_id ACTION DELETE a)""", } diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 7ce2b427..54a59f5d 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -527,11 +527,15 @@ class PGVectorStorage(BaseVectorStorage): return ids_list = ",".join([f"'{id}'" for id in ids]) - delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" - + delete_sql = ( + f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" + ) + try: await self.db.execute(delete_sql, {"workspace": self.db.workspace}) - logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") @@ -543,12 +547,11 @@ class PGVectorStorage(BaseVectorStorage): """ try: # Construct SQL to delete the entity - delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY + delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY WHERE workspace=$1 AND entity_name=$2""" - + await self.db.execute( - delete_sql, - {"workspace": self.db.workspace, "entity_name": entity_name} + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} ) logger.debug(f"Successfully deleted entity {entity_name}") except Exception as e: @@ -562,12 +565,11 @@ class PGVectorStorage(BaseVectorStorage): """ try: # Delete relations where the entity is either the source or target - delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION + delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" - + await self.db.execute( - delete_sql, - {"workspace": self.db.workspace, "entity_name": entity_name} + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} ) logger.debug(f"Successfully deleted relations for entity {entity_name}") except Exception as e: @@ -1167,7 +1169,9 @@ class PGGraphStorage(BaseGraphStorage): Args: node_ids (list[str]): A list of node IDs to remove. 
""" - encoded_node_ids = [self._encode_graph_label(node_id.strip('"')) for node_id in node_ids] + encoded_node_ids = [ + self._encode_graph_label(node_id.strip('"')) for node_id in node_ids + ] node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) query = """SELECT * FROM cypher('%s', $$ @@ -1189,7 +1193,13 @@ class PGGraphStorage(BaseGraphStorage): Args: edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). """ - encoded_edges = [(self._encode_graph_label(src.strip('"')), self._encode_graph_label(tgt.strip('"'))) for src, tgt in edges] + encoded_edges = [ + ( + self._encode_graph_label(src.strip('"')), + self._encode_graph_label(tgt.strip('"')), + ) + for src, tgt in edges + ] edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges]) query = """SELECT * FROM cypher('%s', $$ @@ -1211,10 +1221,13 @@ class PGGraphStorage(BaseGraphStorage): Returns: list[str]: A list of all labels in the graph. """ - query = """SELECT * FROM cypher('%s', $$ + query = ( + """SELECT * FROM cypher('%s', $$ MATCH (n:Entity) RETURN DISTINCT n.node_id AS label - $$) AS (label text)""" % self.graph_name + $$) AS (label text)""" + % self.graph_name + ) results = await self._query(query) labels = [self._decode_graph_label(result["label"]) for result in results] @@ -1260,7 +1273,10 @@ class PGGraphStorage(BaseGraphStorage): OPTIONAL MATCH (n)-[r]->(m:Entity) RETURN n, r, m LIMIT %d - $$) AS (n agtype, r agtype, m agtype)""" % (self.graph_name, MAX_GRAPH_NODES) + $$) AS (n agtype, r agtype, m agtype)""" % ( + self.graph_name, + MAX_GRAPH_NODES, + ) else: encoded_node_label = self._encode_graph_label(node_label.strip('"')) query = """SELECT * FROM cypher('%s', $$ @@ -1268,7 +1284,12 @@ class PGGraphStorage(BaseGraphStorage): OPTIONAL MATCH p = (n)-[*..%d]-(m) RETURN nodes(p) AS nodes, relationships(p) AS relationships LIMIT %d - $$) AS (nodes agtype[], relationships agtype[])""" % (self.graph_name, encoded_node_label, max_depth, MAX_GRAPH_NODES) + $$) AS (nodes agtype[], relationships agtype[])""" % ( + self.graph_name, + encoded_node_label, + max_depth, + MAX_GRAPH_NODES, + ) results = await self._query(query) @@ -1305,29 +1326,6 @@ class PGGraphStorage(BaseGraphStorage): return kg - async def get_all_labels(self) -> list[str]: - """ - Get all node labels in the graph - Returns: - [label1, label2, ...] 
# Alphabetically sorted label list - """ - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity) - RETURN DISTINCT n.node_id AS label - ORDER BY label - $$) AS (label agtype)""" % (self.graph_name) - - try: - results = await self._query(query) - labels = [] - for record in results: - if record["label"]: - labels.append(self._decode_graph_label(record["label"])) - return labels - except Exception as e: - logger.error(f"Error getting all labels: {str(e)}") - return [] - async def drop(self) -> None: """Drop the storage""" drop_sql = SQL_TEMPLATES["drop_vdb_entity"] diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index e3488caa..c7d346e6 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -143,7 +143,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): async def delete(self, ids: List[str]) -> None: """Delete vectors with specified IDs - + Args: ids: List of vector IDs to be deleted """ @@ -156,30 +156,34 @@ class QdrantVectorDBStorage(BaseVectorStorage): points_selector=models.PointIdsList( points=qdrant_ids, ), - wait=True + wait=True, + ) + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) - logger.debug(f"Successfully deleted {len(ids)} vectors from {self.namespace}") except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") async def delete_entity(self, entity_name: str) -> None: """Delete an entity by name - + Args: entity_name: Name of the entity to delete """ try: # Generate the entity ID entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Delete the entity point from the collection self._client.delete( collection_name=self.namespace, points_selector=models.PointIdsList( points=[entity_id], ), - wait=True + wait=True, ) logger.debug(f"Successfully deleted entity {entity_name}") except Exception as e: @@ -187,7 +191,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -198,23 +202,21 @@ class QdrantVectorDBStorage(BaseVectorStorage): scroll_filter=models.Filter( should=[ models.FieldCondition( - key="src_id", - match=models.MatchValue(value=entity_name) + key="src_id", match=models.MatchValue(value=entity_name) ), models.FieldCondition( - key="tgt_id", - match=models.MatchValue(value=entity_name) - ) + key="tgt_id", match=models.MatchValue(value=entity_name) + ), ] ), with_payload=True, - limit=1000 # Adjust as needed for your use case + limit=1000, # Adjust as needed for your use case ) - + # Extract points that need to be deleted relation_points = results[0] ids_to_delete = [point.id for point in relation_points] - + if ids_to_delete: # Delete the relations self._client.delete( @@ -222,9 +224,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): points_selector=models.PointIdsList( points=ids_to_delete, ), - wait=True + wait=True, + ) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" ) - logger.debug(f"Deleted {len(ids_to_delete)} relations for {entity_name}") else: logger.debug(f"No relations found for entity {entity_name}") except Exception as e: diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index bb42b367..3feb4985 100644 --- 
a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -67,35 +67,39 @@ class RedisKVStorage(BaseKVStorage): async def delete(self, ids: list[str]) -> None: """Delete entries with specified IDs - + Args: ids: List of entry IDs to be deleted """ if not ids: return - + pipe = self._redis.pipeline() for id in ids: pipe.delete(f"{self.namespace}:{id}") - + results = await pipe.execute() deleted_count = sum(results) - logger.info(f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}") + logger.info( + f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" + ) async def delete_entity(self, entity_name: str) -> None: """Delete an entity by name - + Args: entity_name: Name of the entity to delete """ - + try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") - + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Delete the entity result = await self._redis.delete(f"{self.namespace}:{entity_id}") - + if result: logger.debug(f"Successfully deleted entity {entity_name}") else: @@ -105,7 +109,7 @@ class RedisKVStorage(BaseKVStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity - + Args: entity_name: Name of the entity whose relations should be deleted """ @@ -114,29 +118,32 @@ class RedisKVStorage(BaseKVStorage): cursor = 0 relation_keys = [] pattern = f"{self.namespace}:*" - + while True: cursor, keys = await self._redis.scan(cursor, match=pattern) - + # For each key, get the value and check if it's related to entity_name for key in keys: value = await self._redis.get(key) if value: data = json.loads(value) # Check if this is a relation involving the entity - if data.get("src_id") == entity_name or data.get("tgt_id") == entity_name: + if ( + data.get("src_id") == entity_name + or data.get("tgt_id") == entity_name + ): relation_keys.append(key) - + # Exit loop when cursor returns to 0 if cursor == 0: break - + # Delete the relation keys if relation_keys: deleted = await self._redis.delete(*relation_keys) logger.debug(f"Deleted {deleted} relations for {entity_name}") else: logger.debug(f"No relations found for entity {entity_name}") - + except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index f791d401..684c30d7 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -567,62 +567,68 @@ class TiDBGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """Delete a node and all its related edges - + Args: node_id: The ID of the node to delete """ # First delete all edges related to this node - await self.db.execute(SQL_TEMPLATES["delete_node_edges"], - {"name": node_id, "workspace": self.db.workspace}) - + await self.db.execute( + SQL_TEMPLATES["delete_node_edges"], + {"name": node_id, "workspace": self.db.workspace}, + ) + # Then delete the node itself - await self.db.execute(SQL_TEMPLATES["delete_node"], - {"name": node_id, "workspace": self.db.workspace}) - - logger.debug(f"Node {node_id} and its related edges have been deleted from the graph") - + await self.db.execute( + SQL_TEMPLATES["delete_node"], + {"name": node_id, "workspace": self.db.workspace}, + ) + + logger.debug( + f"Node {node_id} and its related edges have been deleted from the graph" + ) + async def get_all_labels(self) -> list[str]: """Get all entity types 
(labels) in the database - + Returns: List of labels sorted alphabetically """ result = await self.db.query( - SQL_TEMPLATES["get_all_labels"], - {"workspace": self.db.workspace}, - multirows=True + SQL_TEMPLATES["get_all_labels"], + {"workspace": self.db.workspace}, + multirows=True, ) - + if not result: return [] - + # Extract all labels return [item["label"] for item in result] - + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ Get a connected subgraph of nodes matching the specified label Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000) - + Args: node_label: The node label to match max_depth: Maximum depth of the subgraph - + Returns: KnowledgeGraph object containing nodes and edges """ result = KnowledgeGraph() MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) - + # Get matching nodes if node_label == "*": # Handle special case, get all nodes node_results = await self.db.query( SQL_TEMPLATES["get_all_nodes"], {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, - multirows=True + multirows=True, ) else: # Get nodes matching the label @@ -630,84 +636,93 @@ class TiDBGraphStorage(BaseGraphStorage): node_results = await self.db.query( SQL_TEMPLATES["get_matching_nodes"], {"workspace": self.db.workspace, "label_pattern": label_pattern}, - multirows=True + multirows=True, ) - + if not node_results: logger.warning(f"No nodes found matching label {node_label}") return result - + # Limit the number of returned nodes if len(node_results) > MAX_GRAPH_NODES: node_results = node_results[:MAX_GRAPH_NODES] - + # Extract node names for edge query node_names = [node["name"] for node in node_results] node_names_str = ",".join([f"'{name}'" for name in node_names]) - + # Add nodes to result for node in node_results: - node_properties = {k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]} + node_properties = { + k: v for k, v in node.items() if k not in ["id", "name", "entity_type"] + } result.nodes.append( KnowledgeGraphNode( id=node["name"], - labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]], - properties=node_properties + labels=[node["entity_type"]] + if node.get("entity_type") + else [node["name"]], + properties=node_properties, ) ) - + # Get related edges edge_results = await self.db.query( SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), {"workspace": self.db.workspace}, - multirows=True + multirows=True, ) - + if edge_results: # Add edges to result for edge in edge_results: # Only include edges related to selected nodes - if edge["source_name"] in node_names and edge["target_name"] in node_names: + if ( + edge["source_name"] in node_names + and edge["target_name"] in node_names + ): edge_id = f"{edge['source_name']}-{edge['target_name']}" - edge_properties = {k: v for k, v in edge.items() - if k not in ["id", "source_name", "target_name"]} - + edge_properties = { + k: v + for k, v in edge.items() + if k not in ["id", "source_name", "target_name"] + } + result.edges.append( KnowledgeGraphEdge( id=edge_id, type="RELATED", source=edge["source_name"], target=edge["target_name"], - properties=edge_properties + properties=edge_properties, ) ) - + logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) return result - + async def remove_nodes(self, nodes: list[str]): """Delete multiple nodes - + Args: nodes: List of node IDs to delete """ for node_id in nodes: await 
self.delete_node(node_id) - + async def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges - + Args: edges: List of edges to delete, each edge is a (source, target) tuple """ for source, target in edges: - await self.db.execute(SQL_TEMPLATES["remove_multiple_edges"], { - "source": source, - "target": target, - "workspace": self.db.workspace - }) + await self.db.execute( + SQL_TEMPLATES["remove_multiple_edges"], + {"source": source, "target": target, "workspace": self.db.workspace}, + ) N_T = { @@ -919,26 +934,26 @@ SQL_TEMPLATES = { source_chunk_id = VALUES(source_chunk_id) """, "delete_node": """ - DELETE FROM LIGHTRAG_GRAPH_NODES + DELETE FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace """, "delete_node_edges": """ - DELETE FROM LIGHTRAG_GRAPH_EDGES + DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace """, "get_all_labels": """ - SELECT DISTINCT entity_type as label - FROM LIGHTRAG_GRAPH_NODES + SELECT DISTINCT entity_type as label + FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace ORDER BY entity_type """, "get_matching_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES + SELECT * FROM LIGHTRAG_GRAPH_NODES WHERE name LIKE :label_pattern AND workspace = :workspace ORDER BY name """, "get_all_nodes": """ - SELECT * FROM LIGHTRAG_GRAPH_NODES + SELECT * FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace ORDER BY name LIMIT :max_nodes @@ -952,5 +967,5 @@ SQL_TEMPLATES = { DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE (source_name = :source AND target_name = :target) AND workspace = :workspace - """ + """, } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index eeed8a70..a8034ddd 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1401,40 +1401,54 @@ class LightRAG: def delete_by_relation(self, source_entity: str, target_entity: str) -> None: """Synchronously delete a relation between two entities. - + Args: source_entity: Name of the source entity target_entity: Name of the target entity """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.adelete_by_relation(source_entity, target_entity)) + return loop.run_until_complete( + self.adelete_by_relation(source_entity, target_entity) + ) async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None: """Asynchronously delete a relation between two entities. 
- + Args: source_entity: Name of the source entity target_entity: Name of the target entity """ try: # Check if the relation exists - edge_exists = await self.chunk_entity_relation_graph.has_edge(source_entity, target_entity) + edge_exists = await self.chunk_entity_relation_graph.has_edge( + source_entity, target_entity + ) if not edge_exists: - logger.warning(f"Relation from '{source_entity}' to '{target_entity}' does not exist") + logger.warning( + f"Relation from '{source_entity}' to '{target_entity}' does not exist" + ) return - + # Delete relation from vector database - relation_id = compute_mdhash_id(source_entity + target_entity, prefix="rel-") + relation_id = compute_mdhash_id( + source_entity + target_entity, prefix="rel-" + ) await self.relationships_vdb.delete([relation_id]) - + # Delete relation from knowledge graph - await self.chunk_entity_relation_graph.remove_edges([(source_entity, target_entity)]) - - logger.info(f"Successfully deleted relation from '{source_entity}' to '{target_entity}'") + await self.chunk_entity_relation_graph.remove_edges( + [(source_entity, target_entity)] + ) + + logger.info( + f"Successfully deleted relation from '{source_entity}' to '{target_entity}'" + ) await self._delete_relation_done() except Exception as e: - logger.error(f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}") - + logger.error( + f"Error while deleting relation from '{source_entity}' to '{target_entity}': {e}" + ) + async def _delete_relation_done(self) -> None: """Callback after relation deletion is complete""" await asyncio.gather( From 4ebaf8026b85e5794c132cf4aad92ad582b05605 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 16:11:13 +0800 Subject: [PATCH 27/32] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 5dee1143..754c3491 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -737,6 +737,64 @@ class OracleGraphStorage(BaseGraphStorage): logger.error(f"Error deleting node {node_id}: {e}") raise + async def remove_nodes(self, nodes: list[str]) -> None: + """Delete multiple nodes from the graph + + Args: + nodes: List of node IDs to be deleted + """ + if not nodes: + return + + try: + for node in nodes: + # For each node, first delete all its relationships + delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] + params_relations = {"workspace": self.db.workspace, "entity_name": node} + await self.db.execute(delete_relations_sql, params_relations) + + # Then delete the node itself + delete_node_sql = SQL_TEMPLATES["delete_entity"] + params_node = {"workspace": self.db.workspace, "entity_name": node} + await self.db.execute(delete_node_sql, params_node) + + logger.info(f"Successfully deleted {len(nodes)} nodes and their relationships") + except Exception as e: + logger.error(f"Error during batch node deletion: {e}") + raise + + async def remove_edges(self, edges: list[tuple[str, str]]) -> None: + """Delete multiple edges from the graph + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + if not edges: + return + + try: + for source, target in edges: + # Check if the edge exists before attempting to delete + if await self.has_edge(source, target): + # Delete the edge using a SQL query that matches both source and target + delete_edge_sql = """ + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = :workspace + 
AND source_name = :source_name + AND target_name = :target_name + """ + params = { + "workspace": self.db.workspace, + "source_name": source, + "target_name": target + } + await self.db.execute(delete_edge_sql, params) + + logger.info(f"Successfully deleted {len(edges)} edges from the graph") + except Exception as e: + logger.error(f"Error during batch edge deletion: {e}") + raise + async def get_all_labels(self) -> list[str]: """Get all unique entity types (labels) in the graph From 1ee6c23a53a3418be88e2ecdb5bc8637271712b4 Mon Sep 17 00:00:00 2001 From: zrguo Date: Tue, 4 Mar 2025 16:12:27 +0800 Subject: [PATCH 28/32] fix linting --- lightrag/kg/oracle_impl.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 754c3491..d105aa54 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -739,57 +739,59 @@ class OracleGraphStorage(BaseGraphStorage): async def remove_nodes(self, nodes: list[str]) -> None: """Delete multiple nodes from the graph - + Args: nodes: List of node IDs to be deleted """ if not nodes: return - + try: for node in nodes: # For each node, first delete all its relationships delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] params_relations = {"workspace": self.db.workspace, "entity_name": node} await self.db.execute(delete_relations_sql, params_relations) - + # Then delete the node itself delete_node_sql = SQL_TEMPLATES["delete_entity"] params_node = {"workspace": self.db.workspace, "entity_name": node} await self.db.execute(delete_node_sql, params_node) - - logger.info(f"Successfully deleted {len(nodes)} nodes and their relationships") + + logger.info( + f"Successfully deleted {len(nodes)} nodes and their relationships" + ) except Exception as e: logger.error(f"Error during batch node deletion: {e}") raise async def remove_edges(self, edges: list[tuple[str, str]]) -> None: """Delete multiple edges from the graph - + Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ if not edges: return - + try: for source, target in edges: # Check if the edge exists before attempting to delete if await self.has_edge(source, target): # Delete the edge using a SQL query that matches both source and target delete_edge_sql = """ - DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE workspace = :workspace - AND source_name = :source_name + DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = :workspace + AND source_name = :source_name AND target_name = :target_name """ params = { "workspace": self.db.workspace, "source_name": source, - "target_name": target + "target_name": target, } await self.db.execute(delete_edge_sql, params) - + logger.info(f"Successfully deleted {len(edges)} edges from the graph") except Exception as e: logger.error(f"Error during batch edge deletion: {e}") From 4e59a293fe792f3733a0ff836b6b7fe714ba5526 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Tue, 4 Mar 2025 16:19:23 +0800 Subject: [PATCH 29/32] Update __init__.py --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 2d660928..e4cb3e63 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.2.3" +__version__ = "1.2.4" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From 
From 3264f6a118f572467f84986868217f258a299dfb Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 16:36:58 +0800
Subject: [PATCH 30/32] Update delete_by_doc_id

---
 lightrag/lightrag.py | 92 +++++++++++++++++++++++++-------------------
 1 file changed, 52 insertions(+), 40 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index a8034ddd..e8e468af 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1555,51 +1555,57 @@ class LightRAG:
             await self.text_chunks.delete(chunk_ids)
 
             # 5. Find and process entities and relationships that have these chunks as source
-            # Get all nodes in the graph
-            nodes = self.chunk_entity_relation_graph._graph.nodes(data=True)
-            edges = self.chunk_entity_relation_graph._graph.edges(data=True)
-
-            # Track which entities and relationships need to be deleted or updated
+            # Get all nodes and edges from the graph storage using storage-agnostic methods
             entities_to_delete = set()
             entities_to_update = {}  # entity_name -> new_source_id
             relationships_to_delete = set()
             relationships_to_update = {}  # (src, tgt) -> new_source_id
 
-            # Process entities
-            for node, data in nodes:
-                if "source_id" in data:
+            # Process entities - use storage-agnostic methods
+            all_labels = await self.chunk_entity_relation_graph.get_all_labels()
+            for node_label in all_labels:
+                node_data = await self.chunk_entity_relation_graph.get_node(node_label)
+                if node_data and "source_id" in node_data:
                     # Split source_id using GRAPH_FIELD_SEP
-                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
+                    sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
                     sources.difference_update(chunk_ids)
                     if not sources:
-                        entities_to_delete.add(node)
+                        entities_to_delete.add(node_label)
                         logger.debug(
-                            f"Entity {node} marked for deletion - no remaining sources"
+                            f"Entity {node_label} marked for deletion - no remaining sources"
                         )
                     else:
                         new_source_id = GRAPH_FIELD_SEP.join(sources)
-                        entities_to_update[node] = new_source_id
+                        entities_to_update[node_label] = new_source_id
                         logger.debug(
-                            f"Entity {node} will be updated with new source_id: {new_source_id}"
+                            f"Entity {node_label} will be updated with new source_id: {new_source_id}"
                         )
 
             # Process relationships
-            for src, tgt, data in edges:
-                if "source_id" in data:
-                    # Split source_id using GRAPH_FIELD_SEP
-                    sources = set(data["source_id"].split(GRAPH_FIELD_SEP))
-                    sources.difference_update(chunk_ids)
-                    if not sources:
-                        relationships_to_delete.add((src, tgt))
-                        logger.debug(
-                            f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
-                        )
-                    else:
-                        new_source_id = GRAPH_FIELD_SEP.join(sources)
-                        relationships_to_update[(src, tgt)] = new_source_id
-                        logger.debug(
-                            f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
+            for node_label in all_labels:
+                node_edges = await self.chunk_entity_relation_graph.get_node_edges(
+                    node_label
+                )
+                if node_edges:
+                    for src, tgt in node_edges:
+                        edge_data = await self.chunk_entity_relation_graph.get_edge(
+                            src, tgt
                         )
+                        if edge_data and "source_id" in edge_data:
+                            # Split source_id using GRAPH_FIELD_SEP
+                            sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
+                            sources.difference_update(chunk_ids)
+                            if not sources:
+                                relationships_to_delete.add((src, tgt))
+                                logger.debug(
+                                    f"Relationship {src}-{tgt} marked for deletion - no remaining sources"
+                                )
+                            else:
+                                new_source_id = GRAPH_FIELD_SEP.join(sources)
+                                relationships_to_update[(src, tgt)] = new_source_id
+                                logger.debug(
+                                    f"Relationship {src}-{tgt} will be updated with new source_id: {new_source_id}"
+                                )
 
             # Delete entities
             if entities_to_delete:
@@ -1613,12 +1619,15 @@ class LightRAG:
 
             # Update entities
             for entity, new_source_id in entities_to_update.items():
-                node_data = self.chunk_entity_relation_graph._graph.nodes[entity]
-                node_data["source_id"] = new_source_id
-                await self.chunk_entity_relation_graph.upsert_node(entity, node_data)
-                logger.debug(
-                    f"Updated entity {entity} with new source_id: {new_source_id}"
-                )
+                node_data = await self.chunk_entity_relation_graph.get_node(entity)
+                if node_data:
+                    node_data["source_id"] = new_source_id
+                    await self.chunk_entity_relation_graph.upsert_node(
+                        entity, node_data
+                    )
+                    logger.debug(
+                        f"Updated entity {entity} with new source_id: {new_source_id}"
+                    )
 
             # Delete relationships
             if relationships_to_delete:
@@ -1636,12 +1645,15 @@ class LightRAG:
 
             # Update relationships
             for (src, tgt), new_source_id in relationships_to_update.items():
-                edge_data = self.chunk_entity_relation_graph._graph.edges[src, tgt]
-                edge_data["source_id"] = new_source_id
-                await self.chunk_entity_relation_graph.upsert_edge(src, tgt, edge_data)
-                logger.debug(
-                    f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
-                )
+                edge_data = await self.chunk_entity_relation_graph.get_edge(src, tgt)
+                if edge_data:
+                    edge_data["source_id"] = new_source_id
+                    await self.chunk_entity_relation_graph.upsert_edge(
+                        src, tgt, edge_data
+                    )
+                    logger.debug(
+                        f"Updated relationship {src}-{tgt} with new source_id: {new_source_id}"
+                    )
 
             # 6. Delete original document and status
             await self.full_docs.delete([doc_id])
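Illustrative aside (not part of the patch series): [PATCH 30/32] above replaces direct access to the NetworkX-specific `_graph` attribute with the storage-agnostic graph API (get_all_labels, get_node, get_node_edges, get_edge), so delete_by_doc_id works against any graph backend. A condensed sketch of that traversal pattern follows; `graph` stands for any storage exposing those four methods, and the GRAPH_FIELD_SEP value shown is an assumption made only for this illustration.

# Condensed sketch of the storage-agnostic traversal used in the patch above
# (illustrative only). `graph` is any graph storage exposing get_all_labels,
# get_node, get_node_edges and get_edge.
GRAPH_FIELD_SEP = "<SEP>"  # assumed separator value, for illustration only


async def find_orphaned(graph, chunk_ids: set[str]):
    entities_to_delete: set[str] = set()
    relationships_to_delete: set[tuple[str, str]] = set()

    for label in await graph.get_all_labels():
        node = await graph.get_node(label)
        if node and "source_id" in node:
            # Drop the deleted chunks from the node's source list
            remaining = set(node["source_id"].split(GRAPH_FIELD_SEP)) - chunk_ids
            if not remaining:
                entities_to_delete.add(label)

        for src, tgt in (await graph.get_node_edges(label)) or []:
            edge = await graph.get_edge(src, tgt)
            if edge and "source_id" in edge:
                remaining = set(edge["source_id"].split(GRAPH_FIELD_SEP)) - chunk_ids
                if not remaining:
                    relationships_to_delete.add((src, tgt))

    return entities_to_delete, relationships_to_delete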
From a688b8822a8a9eb3853781bb5c71029a22aa5396 Mon Sep 17 00:00:00 2001
From: Brocowlee
Date: Tue, 4 Mar 2025 10:09:47 +0000
Subject: [PATCH 31/32] [EVO] Add language configuration to environment and argument parsing

---
 env.example                     | 2 +-
 lightrag/api/lightrag_server.py | 3 +++
 lightrag/api/utils_api.py       | 1 +
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/env.example b/env.example
index 112676c6..d0c03a05 100644
--- a/env.example
+++ b/env.example
@@ -47,7 +47,7 @@
 # CHUNK_OVERLAP_SIZE=100
 # MAX_TOKENS=32768          # Max tokens send to LLM for summarization
 # MAX_TOKEN_SUMMARY=500     # Max tokens for entity or relations summary
-# SUMMARY_LANGUAGE=English
+# LANGUAGE=English
 # MAX_EMBED_TOKENS=8192
 
 ### LLM Configuration (Use valid host. For local services installed with docker, you can use host.docker.internal)
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 5f2c437f..93201a20 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -331,6 +331,9 @@ def create_app(args):
             },
             log_level=args.log_level,
             namespace_prefix=args.namespace_prefix,
+            addon_params={
+                "language": args.language,
+            },
             auto_manage_storages_states=False,
         )
     else:  # azure_openai
diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py
index ed1250d4..f865682b 100644
--- a/lightrag/api/utils_api.py
+++ b/lightrag/api/utils_api.py
@@ -340,6 +340,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     # Inject chunk configuration
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
+    args.language = get_env_value("LANGUAGE", "English")
 
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
 
From f1ad55244abb482bef83ef2a9b340ef305e025de Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Tue, 4 Mar 2025 14:44:12 +0100
Subject: [PATCH 32/32] linting

---
 lightrag/api/run_with_gunicorn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py
index 231a1727..cf9b3b91 100644
--- a/lightrag/api/run_with_gunicorn.py
+++ b/lightrag/api/run_with_gunicorn.py
@@ -15,6 +15,7 @@ from dotenv import load_dotenv
 # This update allows the user to put a different.env file for each lightrag folder
 load_dotenv(".env")
 
+
 def check_and_install_dependencies():
     """Check and install required dependencies"""
     required_packages = [
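Illustrative aside (not part of the patch series): [PATCH 31/32] threads a LANGUAGE setting from the environment, through parse_args, into LightRAG's addon_params. The sketch below traces that flow with a simplified stand-in for the get_env_value helper (the real helper lives in lightrag/api/utils_api.py); the default handling here is an assumption made only for illustration.

# Minimal sketch of the LANGUAGE flow added in PATCH 31 (illustrative only).
# get_env_value below is a simplified stand-in, not a copy of the project helper.
import os


def get_env_value(key: str, default, value_type=str):
    raw = os.getenv(key)
    return default if raw is None else value_type(raw)


language = get_env_value("LANGUAGE", "English")

# The server then hands the value to LightRAG via addon_params, as in the
# lightrag_server.py hunk above:
addon_params = {"language": language}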