diff --git a/.gitignore b/.gitignore
index e8130e18..3eb55bd3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -60,3 +60,6 @@ dickens/
 book.txt
 lightrag-dev/
 gui/
+
+# unit-test files
+test_*
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index ed1e546f..f6cee412 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -1,44 +1,46 @@
+"""
+LightRAG FastAPI Server
+"""
+
 from fastapi import (
     FastAPI,
-    HTTPException,
-    File,
-    UploadFile,
-    BackgroundTasks,
+    Depends,
 )
+from fastapi.responses import FileResponse
 import asyncio
 import threading
 import os
-import json
-import re
 from fastapi.staticfiles import StaticFiles
 import logging
-import argparse
-from typing import List, Any, Literal, Optional, Dict
-from pydantic import BaseModel, Field, field_validator
+from typing import Dict
 from pathlib import Path
-import shutil
-import aiofiles
-from ascii_colors import trace_exception, ASCIIColors
-import sys
-from fastapi import Depends, Security
-from fastapi.security import APIKeyHeader
+import configparser
+from ascii_colors import ASCIIColors
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.status import HTTP_403_FORBIDDEN
-import pipmaster as pm
 from dotenv import load_dotenv
-import configparser
-import traceback
-from datetime import datetime
-from lightrag import LightRAG, QueryParam
-from lightrag.base import DocProcessingStatus, DocStatus
+from .utils_api import (
+    get_api_key_dependency,
+    parse_args,
+    get_default_host,
+    display_splash_screen,
+)
+
+from lightrag import LightRAG
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
 from lightrag.utils import logger
-from .ollama_api import OllamaAPI, ollama_server_infos
+from .routers.document_routes import (
+    DocumentManager,
+    create_document_routes,
+    run_scanning_process,
+)
+from .routers.query_routes import create_query_routes
+from .routers.graph_routes import create_graph_routes
+from .routers.ollama_api import OllamaAPI
 
 # Load environment variables
 try:
@@ -50,13 +52,8 @@ except Exception as e:
 config = configparser.ConfigParser()
 config.read("config.ini")
-
-class DefaultRAGStorageConfig:
-    KV_STORAGE = "JsonKVStorage"
-    VECTOR_STORAGE = "NanoVectorDBStorage"
-    GRAPH_STORAGE = "NetworkXStorage"
-    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
-
+# Global configuration
+global_top_k = 60  # default value
 
 # Global progress tracker
 scan_progress: Dict = {
@@ -71,819 +68,16 @@ scan_progress: Dict = {
 progress_lock = threading.Lock()
 
 
-def estimate_tokens(text: str) -> int:
-    """Estimate the number of tokens in text
-    Chinese characters: approximately 1.5 tokens per character
-    English characters: approximately 0.25 tokens per character
-    """
-    # Use regex to match Chinese and non-Chinese characters separately
-    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
-    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
-
-    # Calculate estimated token count
-    tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
-
-    return int(tokens)
-
-
-def get_default_host(binding_type: str) -> str:
-    default_hosts = {
-        "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
-        "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
-        "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
-        "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
-    }
-    return
default_hosts.get( - binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434") - ) # fallback to ollama if unknown - - -def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any: - """ - Get value from environment variable with type conversion - - Args: - env_key (str): Environment variable key - default (Any): Default value if env variable is not set - value_type (type): Type to convert the value to - - Returns: - Any: Converted value from environment or default - """ - value = os.getenv(env_key) - if value is None: - return default - - if value_type is bool: - return value.lower() in ("true", "1", "yes", "t", "on") - try: - return value_type(value) - except ValueError: - return default - - -def display_splash_screen(args: argparse.Namespace) -> None: - """ - Display a colorful splash screen showing LightRAG server configuration - - Args: - args: Parsed command line arguments - """ - # Banner - ASCIIColors.cyan(f""" - ╔══════════════════════════════════════════════════════════════╗ - ║ 🚀 LightRAG Server v{__api_version__} ║ - ║ Fast, Lightweight RAG Server Implementation ║ - ╚══════════════════════════════════════════════════════════════╝ - """) - - # Server Configuration - ASCIIColors.magenta("\n📡 Server Configuration:") - ASCIIColors.white(" ├─ Host: ", end="") - ASCIIColors.yellow(f"{args.host}") - ASCIIColors.white(" ├─ Port: ", end="") - ASCIIColors.yellow(f"{args.port}") - ASCIIColors.white(" ├─ CORS Origins: ", end="") - ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") - ASCIIColors.white(" ├─ SSL Enabled: ", end="") - ASCIIColors.yellow(f"{args.ssl}") - ASCIIColors.white(" └─ API Key: ", end="") - ASCIIColors.yellow("Set" if args.key else "Not Set") - if args.ssl: - ASCIIColors.white(" ├─ SSL Cert: ", end="") - ASCIIColors.yellow(f"{args.ssl_certfile}") - ASCIIColors.white(" └─ SSL Key: ", end="") - ASCIIColors.yellow(f"{args.ssl_keyfile}") - - # Directory Configuration - ASCIIColors.magenta("\n📂 Directory Configuration:") - ASCIIColors.white(" ├─ Working Directory: ", end="") - ASCIIColors.yellow(f"{args.working_dir}") - ASCIIColors.white(" └─ Input Directory: ", end="") - ASCIIColors.yellow(f"{args.input_dir}") - - # LLM Configuration - ASCIIColors.magenta("\n🤖 LLM Configuration:") - ASCIIColors.white(" ├─ Binding: ", end="") - ASCIIColors.yellow(f"{args.llm_binding}") - ASCIIColors.white(" ├─ Host: ", end="") - ASCIIColors.yellow(f"{args.llm_binding_host}") - ASCIIColors.white(" └─ Model: ", end="") - ASCIIColors.yellow(f"{args.llm_model}") - - # Embedding Configuration - ASCIIColors.magenta("\n📊 Embedding Configuration:") - ASCIIColors.white(" ├─ Binding: ", end="") - ASCIIColors.yellow(f"{args.embedding_binding}") - ASCIIColors.white(" ├─ Host: ", end="") - ASCIIColors.yellow(f"{args.embedding_binding_host}") - ASCIIColors.white(" ├─ Model: ", end="") - ASCIIColors.yellow(f"{args.embedding_model}") - ASCIIColors.white(" └─ Dimensions: ", end="") - ASCIIColors.yellow(f"{args.embedding_dim}") - - # RAG Configuration - ASCIIColors.magenta("\n⚙️ RAG Configuration:") - ASCIIColors.white(" ├─ Max Async Operations: ", end="") - ASCIIColors.yellow(f"{args.max_async}") - ASCIIColors.white(" ├─ Max Tokens: ", end="") - ASCIIColors.yellow(f"{args.max_tokens}") - ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") - ASCIIColors.yellow(f"{args.max_embed_tokens}") - ASCIIColors.white(" ├─ Chunk Size: ", end="") - ASCIIColors.yellow(f"{args.chunk_size}") - ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="") - 
ASCIIColors.yellow(f"{args.chunk_overlap_size}") - ASCIIColors.white(" ├─ History Turns: ", end="") - ASCIIColors.yellow(f"{args.history_turns}") - ASCIIColors.white(" ├─ Cosine Threshold: ", end="") - ASCIIColors.yellow(f"{args.cosine_threshold}") - ASCIIColors.white(" └─ Top-K: ", end="") - ASCIIColors.yellow(f"{args.top_k}") - - # System Configuration - ASCIIColors.magenta("\n💾 Storage Configuration:") - ASCIIColors.white(" ├─ KV Storage: ", end="") - ASCIIColors.yellow(f"{args.kv_storage}") - ASCIIColors.white(" ├─ Vector Storage: ", end="") - ASCIIColors.yellow(f"{args.vector_storage}") - ASCIIColors.white(" ├─ Graph Storage: ", end="") - ASCIIColors.yellow(f"{args.graph_storage}") - ASCIIColors.white(" └─ Document Status Storage: ", end="") - ASCIIColors.yellow(f"{args.doc_status_storage}") - - ASCIIColors.magenta("\n🛠️ System Configuration:") - ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") - ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") - ASCIIColors.white(" ├─ Log Level: ", end="") - ASCIIColors.yellow(f"{args.log_level}") - ASCIIColors.white(" ├─ Verbose Debug: ", end="") - ASCIIColors.yellow(f"{args.verbose}") - ASCIIColors.white(" └─ Timeout: ", end="") - ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") - - # Server Status - ASCIIColors.green("\n✨ Server starting up...\n") - - # Server Access Information - protocol = "https" if args.ssl else "http" - if args.host == "0.0.0.0": - ASCIIColors.magenta("\n🌐 Server Access Information:") - ASCIIColors.white(" ├─ Local Access: ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}") - ASCIIColors.white(" ├─ Remote Access: ", end="") - ASCIIColors.yellow(f"{protocol}://:{args.port}") - ASCIIColors.white(" ├─ API Documentation (local): ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs") - ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc") - ASCIIColors.white(" └─ WebUI (local): ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui") - - ASCIIColors.yellow("\n📝 Note:") - ASCIIColors.white(""" Since the server is running on 0.0.0.0: - - Use 'localhost' or '127.0.0.1' for local access - - Use your machine's IP address for remote access - - To find your IP address: - • Windows: Run 'ipconfig' in terminal - • Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal - """) - else: - base_url = f"{protocol}://{args.host}:{args.port}" - ASCIIColors.magenta("\n🌐 Server Access Information:") - ASCIIColors.white(" ├─ Base URL: ", end="") - ASCIIColors.yellow(f"{base_url}") - ASCIIColors.white(" ├─ API Documentation: ", end="") - ASCIIColors.yellow(f"{base_url}/docs") - ASCIIColors.white(" └─ Alternative Documentation: ", end="") - ASCIIColors.yellow(f"{base_url}/redoc") - - # Usage Examples - ASCIIColors.magenta("\n📚 Quick Start Guide:") - ASCIIColors.cyan(""" - 1. Access the Swagger UI: - Open your browser and navigate to the API documentation URL above - - 2. API Authentication:""") - if args.key: - ASCIIColors.cyan(""" Add the following header to your requests: - X-API-Key: - """) - else: - ASCIIColors.cyan(" No authentication required\n") - - ASCIIColors.cyan(""" 3. Basic Operations: - - POST /upload_document: Upload new documents to RAG - - POST /query: Query your document collection - - GET /collections: List available collections - - 4. 
Monitor the server: - - Check server logs for detailed operation information - - Use healthcheck endpoint: GET /health - """) - - # Security Notice - if args.key: - ASCIIColors.yellow("\n⚠️ Security Notice:") - ASCIIColors.white(""" API Key authentication is enabled. - Make sure to include the X-API-Key header in all your requests. - """) - - ASCIIColors.green("Server is ready to accept connections! 🚀\n") - - # Ensure splash output flush to system log - sys.stdout.flush() - - -def parse_args() -> argparse.Namespace: - """ - Parse command line arguments with environment variable fallback - - Returns: - argparse.Namespace: Parsed arguments - """ - - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with separate working and input directories" - ) - - parser.add_argument( - "--kv-storage", - default=get_env_value( - "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE - ), - help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})", - ) - parser.add_argument( - "--doc-status-storage", - default=get_env_value( - "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE - ), - help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", - ) - parser.add_argument( - "--graph-storage", - default=get_env_value( - "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE - ), - help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", - ) - parser.add_argument( - "--vector-storage", - default=get_env_value( - "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE - ), - help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", - ) - - # Bindings configuration - parser.add_argument( - "--llm-binding", - default=get_env_value("LLM_BINDING", "ollama"), - help="LLM binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)", - ) - parser.add_argument( - "--embedding-binding", - default=get_env_value("EMBEDDING_BINDING", "ollama"), - help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)", - ) - - # Server configuration - parser.add_argument( - "--host", - default=get_env_value("HOST", "0.0.0.0"), - help="Server host (default: from env or 0.0.0.0)", - ) - parser.add_argument( - "--port", - type=int, - default=get_env_value("PORT", 9621, int), - help="Server port (default: from env or 9621)", - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default=get_env_value("WORKING_DIR", "./rag_storage"), - help="Working directory for RAG storage (default: from env or ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default=get_env_value("INPUT_DIR", "./inputs"), - help="Directory containing input documents (default: from env or ./inputs)", - ) - - # LLM Model configuration - parser.add_argument( - "--llm-binding-host", - default=get_env_value("LLM_BINDING_HOST", None), - help="LLM server host URL. 
If not provided, defaults based on llm-binding:\n" - + "- ollama: http://localhost:11434\n" - + "- lollms: http://localhost:9600\n" - + "- openai: https://api.openai.com/v1", - ) - - default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None) - - parser.add_argument( - "--llm-binding-api-key", - default=default_llm_api_key, - help="llm server API key (default: from env or empty string)", - ) - - parser.add_argument( - "--llm-model", - default=get_env_value("LLM_MODEL", "mistral-nemo:latest"), - help="LLM model name (default: from env or mistral-nemo:latest)", - ) - - # Embedding model configuration - parser.add_argument( - "--embedding-binding-host", - default=get_env_value("EMBEDDING_BINDING_HOST", None), - help="Embedding server host URL. If not provided, defaults based on embedding-binding:\n" - + "- ollama: http://localhost:11434\n" - + "- lollms: http://localhost:9600\n" - + "- openai: https://api.openai.com/v1", - ) - - default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") - parser.add_argument( - "--embedding-binding-api-key", - default=default_embedding_api_key, - help="embedding server API key (default: from env or empty string)", - ) - - parser.add_argument( - "--embedding-model", - default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"), - help="Embedding model name (default: from env or bge-m3:latest)", - ) - - parser.add_argument( - "--chunk_size", - default=get_env_value("CHUNK_SIZE", 1200), - help="chunk chunk size default 1200", - ) - - parser.add_argument( - "--chunk_overlap_size", - default=get_env_value("CHUNK_OVERLAP_SIZE", 100), - help="chunk overlap size default 100", - ) - - def timeout_type(value): - if value is None or value == "None": - return None - return int(value) - - parser.add_argument( - "--timeout", - default=get_env_value("TIMEOUT", None, timeout_type), - type=timeout_type, - help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", - ) - - # RAG configuration - parser.add_argument( - "--max-async", - type=int, - default=get_env_value("MAX_ASYNC", 4, int), - help="Maximum async operations (default: from env or 4)", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=get_env_value("MAX_TOKENS", 32768, int), - help="Maximum token size (default: from env or 32768)", - ) - parser.add_argument( - "--embedding-dim", - type=int, - default=get_env_value("EMBEDDING_DIM", 1024, int), - help="Embedding dimensions (default: from env or 1024)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=get_env_value("MAX_EMBED_TOKENS", 8192, int), - help="Maximum embedding token size (default: from env or 8192)", - ) - - # Logging configuration - parser.add_argument( - "--log-level", - default=get_env_value("LOG_LEVEL", "INFO"), - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: from env or INFO)", - ) - - parser.add_argument( - "--key", - type=str, - default=get_env_value("LIGHTRAG_API_KEY", None), - help="API key for authentication. 
This protects lightrag server against unauthorized access", - ) - - # Optional https parameters - parser.add_argument( - "--ssl", - action="store_true", - default=get_env_value("SSL", False, bool), - help="Enable HTTPS (default: from env or False)", - ) - parser.add_argument( - "--ssl-certfile", - default=get_env_value("SSL_CERTFILE", None), - help="Path to SSL certificate file (required if --ssl is enabled)", - ) - parser.add_argument( - "--ssl-keyfile", - default=get_env_value("SSL_KEYFILE", None), - help="Path to SSL private key file (required if --ssl is enabled)", - ) - parser.add_argument( - "--auto-scan-at-startup", - action="store_true", - default=False, - help="Enable automatic scanning when the program starts", - ) - - parser.add_argument( - "--history-turns", - type=int, - default=get_env_value("HISTORY_TURNS", 3, int), - help="Number of conversation history turns to include (default: from env or 3)", - ) - - # Search parameters - parser.add_argument( - "--top-k", - type=int, - default=get_env_value("TOP_K", 60, int), - help="Number of most similar results to return (default: from env or 60)", - ) - parser.add_argument( - "--cosine-threshold", - type=float, - default=get_env_value("COSINE_THRESHOLD", 0.2, float), - help="Cosine similarity threshold (default: from env or 0.4)", - ) - - # Ollama model name - parser.add_argument( - "--simulated-model-name", - type=str, - default=get_env_value( - "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL - ), - help="Number of conversation history turns to include (default: from env or 3)", - ) - - # Namespace - parser.add_argument( - "--namespace-prefix", - type=str, - default=get_env_value("NAMESPACE_PREFIX", ""), - help="Prefix of the namespace", - ) - - parser.add_argument( - "--verbose", - type=bool, - default=get_env_value("VERBOSE", False, bool), - help="Verbose debug output(default: from env or false)", - ) - - args = parser.parse_args() - - # conver relative path to absolute path - args.working_dir = os.path.abspath(args.working_dir) - args.input_dir = os.path.abspath(args.input_dir) - - ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name - - return args - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__( - self, - input_dir: str, - supported_extensions: tuple = ( - ".txt", - ".md", - ".pdf", - ".docx", - ".pptx", - ".xlsx", - ".rtf", # Rich Text Format - ".odt", # OpenDocument Text - ".tex", # LaTeX - ".epub", # Electronic Publication - ".html", # HyperText Markup Language - ".htm", # HyperText Markup Language - ".csv", # Comma-Separated Values - ".json", # JavaScript Object Notation - ".xml", # eXtensible Markup Language - ".yaml", # YAML Ain't Markup Language - ".yml", # YAML - ".log", # Log files - ".conf", # Configuration files - ".ini", # Initialization files - ".properties", # Java properties files - ".sql", # SQL scripts - ".bat", # Batch files - ".sh", # Shell scripts - ".c", # C source code - ".cpp", # C++ source code - ".py", # Python source code - ".java", # Java source code - ".js", # JavaScript source code - ".ts", # TypeScript source code - ".swift", # Swift source code - ".go", # Go source code - ".rb", # Ruby source code - ".php", # PHP source code - ".css", # Cascading Style Sheets - ".scss", # Sassy CSS - ".less", # LESS CSS - ), - ): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - 
def scan_directory_for_new_files(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - logger.info(f"Scanning for {ext} files in {self.input_dir}") - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -class QueryRequest(BaseModel): - query: str = Field( - min_length=1, - description="The query text", - ) - - mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field( - default="hybrid", - description="Query mode", - ) - - only_need_context: Optional[bool] = Field( - default=None, - description="If True, only returns the retrieved context without generating a response.", - ) - - only_need_prompt: Optional[bool] = Field( - default=None, - description="If True, only returns the generated prompt without producing a response.", - ) - - response_type: Optional[str] = Field( - min_length=1, - default=None, - description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", - ) - - top_k: Optional[int] = Field( - ge=1, - default=None, - description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", - ) - - max_token_for_text_unit: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allowed for each retrieved text chunk.", - ) - - max_token_for_global_context: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", - ) - - max_token_for_local_context: Optional[int] = Field( - gt=1, - default=None, - description="Maximum number of tokens allocated for entity descriptions in local retrieval.", - ) - - hl_keywords: Optional[List[str]] = Field( - default=None, - description="List of high-level keywords to prioritize in retrieval.", - ) - - ll_keywords: Optional[List[str]] = Field( - default=None, - description="List of low-level keywords to refine retrieval focus.", - ) - - conversation_history: Optional[List[dict[str, Any]]] = Field( - default=None, - description="Stores past conversation history to maintain context. 
Format: [{'role': 'user/assistant', 'content': 'message'}].", - ) - - history_turns: Optional[int] = Field( - ge=0, - default=None, - description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", - ) - - @field_validator("query", mode="after") - @classmethod - def query_strip_after(cls, query: str) -> str: - return query.strip() - - @field_validator("hl_keywords", mode="after") - @classmethod - def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None: - if hl_keywords is None: - return None - return [keyword.strip() for keyword in hl_keywords] - - @field_validator("ll_keywords", mode="after") - @classmethod - def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None: - if ll_keywords is None: - return None - return [keyword.strip() for keyword in ll_keywords] - - @field_validator("conversation_history", mode="after") - @classmethod - def conversation_history_role_check( - cls, conversation_history: List[dict[str, Any]] | None - ) -> List[dict[str, Any]] | None: - if conversation_history is None: - return None - for msg in conversation_history: - if "role" not in msg or msg["role"] not in {"user", "assistant"}: - raise ValueError( - "Each message must have a 'role' key with value 'user' or 'assistant'." - ) - return conversation_history - - def to_query_params(self, is_stream: bool) -> QueryParam: - """Converts a QueryRequest instance into a QueryParam instance.""" - # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically - request_data = self.model_dump(exclude_none=True, exclude={"query"}) - - # Ensure `mode` and `stream` are set explicitly - param = QueryParam(**request_data) - param.stream = is_stream - return param - - -class QueryResponse(BaseModel): - response: str = Field( - description="The generated response", - ) - - -class InsertTextRequest(BaseModel): - text: str = Field( - min_length=1, - description="The text to insert", - ) - - @field_validator("text", mode="after") - @classmethod - def strip_after(cls, text: str) -> str: - return text.strip() - - -class InsertTextsRequest(BaseModel): - texts: list[str] = Field( - min_length=1, - description="The texts to insert", - ) - - @field_validator("texts", mode="after") - @classmethod - def strip_after(cls, texts: list[str]) -> list[str]: - return [text.strip() for text in texts] - - -class InsertResponse(BaseModel): - status: str = Field(description="Status of the operation") - message: str = Field(description="Message describing the operation result") - - -class DocStatusResponse(BaseModel): - @staticmethod - def format_datetime(dt: Any) -> Optional[str]: - """Format datetime to ISO string - - Args: - dt: Datetime object or string - - Returns: - Formatted datetime string or None - """ - if dt is None: - return None - if isinstance(dt, str): - return dt - return dt.isoformat() - - """Response model for document status - - Attributes: - id: Document identifier - content_summary: Summary of document content - content_length: Length of document content - status: Current processing status - created_at: Creation timestamp (ISO format string) - updated_at: Last update timestamp (ISO format string) - chunks_count: Number of chunks (optional) - error: Error message if any (optional) - metadata: Additional metadata (optional) - """ - - id: str - content_summary: str - content_length: int - status: DocStatus - created_at: str - updated_at: str - chunks_count: Optional[int] = None - error: Optional[str] = None - 
metadata: Optional[dict[str, Any]] = None - - -class DocsStatusesResponse(BaseModel): - statuses: Dict[DocStatus, List[DocStatusResponse]] = {} - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth( - api_key_header_value: Optional[str] = Security(api_key_header), - ): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -# Global configuration -global_top_k = 60 # default value -temp_prefix = "__tmp_" # prefix for temporary files - - def create_app(args): + # Set global top_k + global global_top_k + global_top_k = args.top_k # save top_k from args + # Initialize verbose debug setting from lightrag.utils import set_verbose_debug set_verbose_debug(args.verbose) - global global_top_k - global_top_k = args.top_k # save top_k from args - # Verify that bindings are correctly setup if args.llm_binding not in [ "lollms", @@ -945,7 +139,9 @@ def create_app(args): scan_progress["indexed_count"] = 0 scan_progress["progress"] = 0 # Create background task - task = asyncio.create_task(run_scanning_process()) + task = asyncio.create_task( + run_scanning_process(rag, doc_manager) + ) app.state.background_tasks.add(task) task.add_done_callback(app.state.background_tasks.discard) ASCIIColors.info( @@ -953,7 +149,7 @@ def create_app(args): ) else: ASCIIColors.info( - "Skip document scanning(anohter scanning is active)" + "Skip document scanning(another scanning is active)" ) yield @@ -1161,656 +357,15 @@ def create_app(args): auto_manage_storages_states=False, ) - async def pipeline_enqueue_file(file_path: Path) -> bool: - """Add a file to the queue for processing - - Args: - file_path: Path to the saved file - Returns: - bool: True if the file was successfully enqueued, False otherwise - """ - try: - content = "" - ext = file_path.suffix.lower() - - file = None - async with aiofiles.open(file_path, "rb") as f: - file = await f.read() - - # Process based on file type - match ext: - case ( - ".txt" - | ".md" - | ".html" - | ".htm" - | ".tex" - | ".json" - | ".xml" - | ".yaml" - | ".yml" - | ".rtf" - | ".odt" - | ".epub" - | ".csv" - | ".log" - | ".conf" - | ".ini" - | ".properties" - | ".sql" - | ".bat" - | ".sh" - | ".c" - | ".cpp" - | ".py" - | ".java" - | ".js" - | ".ts" - | ".swift" - | ".go" - | ".rb" - | ".php" - | ".css" - | ".scss" - | ".less" - ): - content = file.decode("utf-8") - - case ".pdf": - if not pm.is_installed("pypdf2"): - pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore - from io import BytesIO - - pdf_file = BytesIO(file) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" - case ".docx": - if not pm.is_installed("docx"): - pm.install("docx") - from docx import Document - from io import BytesIO - - docx_file = BytesIO(file) - doc = Document(docx_file) - content = "\n".join( - [paragraph.text for paragraph in doc.paragraphs] - ) - case ".pptx": - if not pm.is_installed("pptx"): - pm.install("pptx") - from pptx import Presentation # type: ignore - from io import BytesIO - - 
pptx_file = BytesIO(file) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += shape.text + "\n" - case ".xlsx": - if not pm.is_installed("openpyxl"): - pm.install("openpyxl") - from openpyxl import load_workbook # type: ignore - from io import BytesIO - - xlsx_file = BytesIO(file) - wb = load_workbook(xlsx_file) - for sheet in wb: - content += f"Sheet: {sheet.title}\n" - for row in sheet.iter_rows(values_only=True): - content += ( - "\t".join( - str(cell) if cell is not None else "" - for cell in row - ) - + "\n" - ) - content += "\n" - case _: - logging.error( - f"Unsupported file type: {file_path.name} (extension {ext})" - ) - return False - - # Insert into the RAG queue - if content: - await rag.apipeline_enqueue_documents(content) - logging.info( - f"Successfully fetched and enqueued file: {file_path.name}" - ) - return True - else: - logging.error( - f"No content could be extracted from file: {file_path.name}" - ) - - except Exception as e: - logging.error( - f"Error processing or enqueueing file {file_path.name}: {str(e)}" - ) - logging.error(traceback.format_exc()) - finally: - if file_path.name.startswith(temp_prefix): - # Clean up the temporary file after indexing - try: - file_path.unlink() - except Exception as e: - logging.error(f"Error deleting file {file_path}: {str(e)}") - return False - - async def pipeline_index_file(file_path: Path): - """Index a file - - Args: - file_path: Path to the saved file - """ - try: - if await pipeline_enqueue_file(file_path): - await rag.apipeline_process_enqueue_documents() - - except Exception as e: - logging.error(f"Error indexing file {file_path.name}: {str(e)}") - logging.error(traceback.format_exc()) - - async def pipeline_index_files(file_paths: List[Path]): - """Index multiple files concurrently - - Args: - file_paths: Paths to the files to index - """ - if not file_paths: - return - try: - enqueued = False - - if len(file_paths) == 1: - enqueued = await pipeline_enqueue_file(file_paths[0]) - else: - tasks = [pipeline_enqueue_file(path) for path in file_paths] - enqueued = any(await asyncio.gather(*tasks)) - - if enqueued: - await rag.apipeline_process_enqueue_documents() - except Exception as e: - logging.error(f"Error indexing files: {str(e)}") - logging.error(traceback.format_exc()) - - async def pipeline_index_texts(texts: List[str]): - """Index a list of texts - - Args: - texts: The texts to index - """ - if not texts: - return - await rag.apipeline_enqueue_documents(texts) - await rag.apipeline_process_enqueue_documents() - - async def save_temp_file(file: UploadFile = File(...)) -> Path: - """Save the uploaded file to a temporary location - - Args: - file: The uploaded file - - Returns: - Path: The path to the saved file - """ - # Generate unique filename to avoid conflicts - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" - - # Create a temporary file to save the uploaded content - temp_path = doc_manager.input_dir / "temp" / unique_filename - temp_path.parent.mkdir(exist_ok=True) - - # Save the file - with open(temp_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - return temp_path - - async def run_scanning_process(): - """Background task to scan and index documents""" - global scan_progress - - try: - new_files = doc_manager.scan_directory_for_new_files() - scan_progress["total_files"] = len(new_files) - - logger.info(f"Found {len(new_files)} new files to index.") - 
for file_path in new_files: - try: - with progress_lock: - scan_progress["current_file"] = os.path.basename(file_path) - - await pipeline_index_file(file_path) - - with progress_lock: - scan_progress["indexed_count"] += 1 - scan_progress["progress"] = ( - scan_progress["indexed_count"] - / scan_progress["total_files"] - ) * 100 - - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - except Exception as e: - logging.error(f"Error during scanning process: {str(e)}") - finally: - with progress_lock: - scan_progress["is_scanning"] = False - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(background_tasks: BackgroundTasks): - """Trigger the scanning process""" - global scan_progress - - with progress_lock: - if scan_progress["is_scanning"]: - return {"status": "already_scanning"} - - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 - - # Start the scanning process in the background - background_tasks.add_task(run_scanning_process) - - return {"status": "scanning_started"} - - @app.get("/documents/scan-progress") - async def get_scan_progress(): - """Get the current scanning progress""" - with progress_lock: - return scan_progress - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) - ): - """ - Endpoint for uploading a file to the input directory and indexing it. - - This API endpoint accepts a file through an HTTP POST request, checks if the - uploaded file is of a supported type, saves it in the specified input directory, - indexes it for retrieval, and returns a success status with relevant details. - - Parameters: - background_tasks: FastAPI BackgroundTasks for async processing - file (UploadFile): The file to be uploaded. It must have an allowed extension as per - `doc_manager.supported_extensions`. - - Returns: - dict: A dictionary containing the upload status ("success"), - a message detailing the operation result, and - the total number of indexed documents. - - Raises: - HTTPException: If the file type is not supported, it raises a 400 Bad Request error. - If any other exception occurs during the file handling or indexing, - it raises a 500 Internal Server Error with details about the exception. - """ - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Add to background tasks - background_tasks.add_task(pipeline_index_file, file_path) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text( - request: InsertTextRequest, background_tasks: BackgroundTasks - ): - """ - Insert text into the Retrieval-Augmented Generation (RAG) system. 
- - This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. - - Args: - request (InsertTextRequest): The request body containing the text to be inserted. - background_tasks: FastAPI BackgroundTasks for async processing - - Returns: - InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. - """ - try: - background_tasks.add_task(pipeline_index_texts, [request.text]) - return InsertResponse( - status="success", - message="Text successfully received. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/texts", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_texts( - request: InsertTextsRequest, background_tasks: BackgroundTasks - ): - """ - Insert texts into the Retrieval-Augmented Generation (RAG) system. - - This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses. - - Args: - request (InsertTextsRequest): The request body containing the text to be inserted. - background_tasks: FastAPI BackgroundTasks for async processing - - Returns: - InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted. - """ - try: - background_tasks.add_task(pipeline_index_texts, request.texts) - return InsertResponse( - status="success", - message="Text successfully received. Processing will continue in background.", - ) - except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file( - background_tasks: BackgroundTasks, file: UploadFile = File(...) - ): - """Insert a file directly into the RAG system - - Args: - background_tasks: FastAPI BackgroundTasks for async processing - file: Uploaded file - - Returns: - InsertResponse: Status of the insertion operation - - Raises: - HTTPException: For unsupported file types or processing errors - """ - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - # Create a temporary file to save the uploaded content - temp_path = save_temp_file(file) - - # Add to background tasks - background_tasks.add_task(pipeline_index_file, temp_path) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' saved successfully. Processing will continue in background.", - ) - - except Exception as e: - logging.error(f"Error /documents/file: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file_batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch( - background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) 
- ): - """Process multiple files in batch mode - - Args: - background_tasks: FastAPI BackgroundTasks for async processing - files: List of files to process - - Returns: - InsertResponse: Status of the batch insertion operation - - Raises: - HTTPException: For processing errors - """ - try: - inserted_count = 0 - failed_files = [] - temp_files = [] - - for file in files: - if doc_manager.is_supported_file(file.filename): - # Create a temporary file to save the uploaded content - temp_files.append(save_temp_file(file)) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - - if temp_files: - background_tasks.add_task(pipeline_index_files, temp_files) - - # Prepare status message - if inserted_count == len(files): - status = "success" - status_message = f"Successfully inserted all {inserted_count} documents" - elif inserted_count > 0: - status = "partial_success" - status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - else: - status = "failure" - status_message = "No documents were successfully inserted" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse(status=status, message=status_message) - - except Exception as e: - logging.error(f"Error /documents/batch: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - """ - Clear all documents from the LightRAG system. - - This endpoint deletes all text chunks, entities vector database, and relationships vector database, - effectively clearing all documents from the LightRAG system. - - Returns: - InsertResponse: A response object containing the status, message, and the new document count (0 in this case). - """ - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", message="All documents cleared successfully" - ) - except Exception as e: - logging.error(f"Error DELETE /documents: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - """ - Handle a POST request at the /query endpoint to process user queries using RAG capabilities. - - Parameters: - request (QueryRequest): The request object containing the query parameters. - Returns: - QueryResponse: A Pydantic model containing the result of the query processing. - If a string is returned (e.g., cache hit), it's directly returned. - Otherwise, an async generator may be used to build the response. - - Raises: - HTTPException: Raised when an error occurs during the request handling process, - with status code 500 and detail containing the exception message. - """ - try: - response = await rag.aquery( - request.query, param=request.to_query_params(False) - ) - - # If response is a string (e.g. 
cache hit), return directly - if isinstance(response, str): - return QueryResponse(response=response) - - if isinstance(response, dict): - result = json.dumps(response, indent=2) - return QueryResponse(response=result) - else: - return QueryResponse(response=str(response)) - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - """ - This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. - - Args: - request (QueryRequest): The request object containing the query parameters. - optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None. - - Returns: - StreamingResponse: A streaming response containing the RAG query results. - """ - try: - response = await rag.aquery( - request.query, param=request.to_query_params(True) - ) - - from fastapi.responses import StreamingResponse - - async def stream_generator(): - if isinstance(response, str): - # If it's a string, send it all at once - yield f"{json.dumps({'response': response})}\n" - else: - # If it's an async generator, send chunks one by one - try: - async for chunk in response: - if chunk: # Only send non-empty content - yield f"{json.dumps({'response': chunk})}\n" - except Exception as e: - logging.error(f"Streaming error: {str(e)}") - yield f"{json.dumps({'error': str(e)})}\n" - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx - }, - ) - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - # query all graph labels - @app.get("/graph/label/list") - async def get_graph_labels(): - return await rag.get_graph_labels() - - # query all graph - @app.get("/graphs") - async def get_knowledge_graph(label: str): - return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + # Add routes + app.include_router(create_document_routes(rag, doc_manager, api_key)) + app.include_router(create_query_routes(rag, api_key, args.top_k)) + app.include_router(create_graph_routes(rag, api_key)) # Add Ollama API routes ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") - @app.get("/documents", dependencies=[Depends(optional_api_key)]) - async def documents() -> DocsStatusesResponse: - """ - Get documents statuses - Returns: - DocsStatusesResponse: A response object containing a dictionary where keys are DocStatus - and values are lists of DocStatusResponse objects representing documents in each status category. 
- """ - try: - statuses = ( - DocStatus.PENDING, - DocStatus.PROCESSING, - DocStatus.PROCESSED, - DocStatus.FAILED, - ) - - tasks = [rag.get_docs_by_status(status) for status in statuses] - results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks) - - response = DocsStatusesResponse() - - for idx, result in enumerate(results): - status = statuses[idx] - for doc_id, doc_status in result.items(): - if status not in response.statuses: - response.statuses[status] = [] - response.statuses[status].append( - DocStatusResponse( - id=doc_id, - content_summary=doc_status.content_summary, - content_length=doc_status.content_length, - status=doc_status.status, - created_at=DocStatusResponse.format_datetime( - doc_status.created_at - ), - updated_at=DocStatusResponse.format_datetime( - doc_status.updated_at - ), - chunks_count=doc_status.chunks_count, - error=doc_status.error, - metadata=doc_status.metadata, - ) - ) - return response - except Exception as e: - logging.error(f"Error GET /documents: {str(e)}") - logging.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" @@ -1838,7 +393,15 @@ def create_app(args): # Webui mount webui/index.html static_dir = Path(__file__).parent / "webui" static_dir.mkdir(exist_ok=True) - app.mount("/webui", StaticFiles(directory=static_dir, html=True), name="webui") + app.mount( + "/webui", + StaticFiles(directory=static_dir, html=True, check_dir=True), + name="webui", + ) + + @app.get("/webui/") + async def webui_root(): + return FileResponse(static_dir / "index.html") return app diff --git a/lightrag/api/routers/__init__.py b/lightrag/api/routers/__init__.py new file mode 100644 index 00000000..b71f204e --- /dev/null +++ b/lightrag/api/routers/__init__.py @@ -0,0 +1,10 @@ +""" +This module contains all the routers for the LightRAG API. +""" + +from .document_routes import router as document_router +from .query_routes import router as query_router +from .graph_routes import router as graph_router +from .ollama_api import OllamaAPI + +__all__ = ["document_router", "query_router", "graph_router", "OllamaAPI"] diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py new file mode 100644 index 00000000..25ca24e4 --- /dev/null +++ b/lightrag/api/routers/document_routes.py @@ -0,0 +1,770 @@ +""" +This module contains all document-related routes for the LightRAG API. 
+""" + +import asyncio +import logging +import os +import aiofiles +import shutil +import traceback +import pipmaster as pm +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any + +from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile +from pydantic import BaseModel, Field, field_validator + +from lightrag import LightRAG +from lightrag.base import DocProcessingStatus, DocStatus +from ..utils_api import get_api_key_dependency + + +router = APIRouter(prefix="/documents", tags=["documents"]) + +# Global progress tracker +scan_progress: Dict = { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, +} + +# Lock for thread-safe operations +progress_lock = asyncio.Lock() + +# Temporary file prefix +temp_prefix = "__tmp__" + + +class InsertTextRequest(BaseModel): + text: str = Field( + min_length=1, + description="The text to insert", + ) + + @field_validator("text", mode="after") + @classmethod + def strip_after(cls, text: str) -> str: + return text.strip() + + +class InsertTextsRequest(BaseModel): + texts: list[str] = Field( + min_length=1, + description="The texts to insert", + ) + + @field_validator("texts", mode="after") + @classmethod + def strip_after(cls, texts: list[str]) -> list[str]: + return [text.strip() for text in texts] + + +class InsertResponse(BaseModel): + status: str = Field(description="Status of the operation") + message: str = Field(description="Message describing the operation result") + + +class DocStatusResponse(BaseModel): + @staticmethod + def format_datetime(dt: Any) -> Optional[str]: + if dt is None: + return None + if isinstance(dt, str): + return dt + return dt.isoformat() + + """Response model for document status + + Attributes: + id: Document identifier + content_summary: Summary of document content + content_length: Length of document content + status: Current processing status + created_at: Creation timestamp (ISO format string) + updated_at: Last update timestamp (ISO format string) + chunks_count: Number of chunks (optional) + error: Error message if any (optional) + metadata: Additional metadata (optional) + """ + + id: str + content_summary: str + content_length: int + status: DocStatus + created_at: str + updated_at: str + chunks_count: Optional[int] = None + error: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + + +class DocsStatusesResponse(BaseModel): + statuses: Dict[DocStatus, List[DocStatusResponse]] = {} + + +class DocumentManager: + def __init__( + self, + input_dir: str, + supported_extensions: tuple = ( + ".txt", + ".md", + ".pdf", + ".docx", + ".pptx", + ".xlsx", + ".rtf", # Rich Text Format + ".odt", # OpenDocument Text + ".tex", # LaTeX + ".epub", # Electronic Publication + ".html", # HyperText Markup Language + ".htm", # HyperText Markup Language + ".csv", # Comma-Separated Values + ".json", # JavaScript Object Notation + ".xml", # eXtensible Markup Language + ".yaml", # YAML Ain't Markup Language + ".yml", # YAML + ".log", # Log files + ".conf", # Configuration files + ".ini", # Initialization files + ".properties", # Java properties files + ".sql", # SQL scripts + ".bat", # Batch files + ".sh", # Shell scripts + ".c", # C source code + ".cpp", # C++ source code + ".py", # Python source code + ".java", # Java source code + ".js", # JavaScript source code + ".ts", # TypeScript source code + ".swift", # Swift source code + ".go", # Go source code + ".rb", # Ruby source code + ".php", # PHP 
source code + ".css", # Cascading Style Sheets + ".scss", # Sassy CSS + ".less", # LESS CSS + ), + ): + self.input_dir = Path(input_dir) + self.supported_extensions = supported_extensions + self.indexed_files = set() + + # Create input directory if it doesn't exist + self.input_dir.mkdir(parents=True, exist_ok=True) + + def scan_directory_for_new_files(self) -> List[Path]: + """Scan input directory for new files""" + new_files = [] + for ext in self.supported_extensions: + logging.info(f"Scanning for {ext} files in {self.input_dir}") + for file_path in self.input_dir.rglob(f"*{ext}"): + if file_path not in self.indexed_files: + new_files.append(file_path) + return new_files + + # def scan_directory(self) -> List[Path]: + # new_files = [] + # for ext in self.supported_extensions: + # for file_path in self.input_dir.rglob(f"*{ext}"): + # new_files.append(file_path) + # return new_files + + def mark_as_indexed(self, file_path: Path): + self.indexed_files.add(file_path) + + def is_supported_file(self, filename: str) -> bool: + return any(filename.lower().endswith(ext) for ext in self.supported_extensions) + + +async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: + """Add a file to the queue for processing + + Args: + rag: LightRAG instance + file_path: Path to the saved file + Returns: + bool: True if the file was successfully enqueued, False otherwise + """ + + try: + content = "" + ext = file_path.suffix.lower() + + file = None + async with aiofiles.open(file_path, "rb") as f: + file = await f.read() + + # Process based on file type + match ext: + case ( + ".txt" + | ".md" + | ".html" + | ".htm" + | ".tex" + | ".json" + | ".xml" + | ".yaml" + | ".yml" + | ".rtf" + | ".odt" + | ".epub" + | ".csv" + | ".log" + | ".conf" + | ".ini" + | ".properties" + | ".sql" + | ".bat" + | ".sh" + | ".c" + | ".cpp" + | ".py" + | ".java" + | ".js" + | ".ts" + | ".swift" + | ".go" + | ".rb" + | ".php" + | ".css" + | ".scss" + | ".less" + ): + content = file.decode("utf-8") + case ".pdf": + if not pm.is_installed("pypdf2"): + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO + + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + case ".docx": + if not pm.is_installed("docx"): + pm.install("docx") + from docx import Document + from io import BytesIO + + docx_file = BytesIO(file) + doc = Document(docx_file) + content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) + case ".pptx": + if not pm.is_installed("pptx"): + pm.install("pptx") + from pptx import Presentation + from io import BytesIO + + pptx_file = BytesIO(file) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + case ".xlsx": + if not pm.is_installed("openpyxl"): + pm.install("openpyxl") + from openpyxl import load_workbook + from io import BytesIO + + xlsx_file = BytesIO(file) + wb = load_workbook(xlsx_file) + for sheet in wb: + content += f"Sheet: {sheet.title}\n" + for row in sheet.iter_rows(values_only=True): + content += ( + "\t".join( + str(cell) if cell is not None else "" for cell in row + ) + + "\n" + ) + content += "\n" + case _: + logging.error( + f"Unsupported file type: {file_path.name} (extension {ext})" + ) + return False + + # Insert into the RAG queue + if content: + await rag.apipeline_enqueue_documents(content) + logging.info(f"Successfully fetched and enqueued file: {file_path.name}") + 
return True + else: + logging.error(f"No content could be extracted from file: {file_path.name}") + + except Exception as e: + logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + finally: + if file_path.name.startswith(temp_prefix): + try: + file_path.unlink() + except Exception as e: + logging.error(f"Error deleting file {file_path}: {str(e)}") + return False + + +async def pipeline_index_file(rag: LightRAG, file_path: Path): + """Index a file + + Args: + rag: LightRAG instance + file_path: Path to the saved file + """ + try: + if await pipeline_enqueue_file(rag, file_path): + await rag.apipeline_process_enqueue_documents() + + except Exception as e: + logging.error(f"Error indexing file {file_path.name}: {str(e)}") + logging.error(traceback.format_exc()) + + +async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): + """Index multiple files concurrently + + Args: + rag: LightRAG instance + file_paths: Paths to the files to index + """ + if not file_paths: + return + try: + enqueued = False + + if len(file_paths) == 1: + enqueued = await pipeline_enqueue_file(rag, file_paths[0]) + else: + tasks = [pipeline_enqueue_file(rag, path) for path in file_paths] + enqueued = any(await asyncio.gather(*tasks)) + + if enqueued: + await rag.apipeline_process_enqueue_documents() + except Exception as e: + logging.error(f"Error indexing files: {str(e)}") + logging.error(traceback.format_exc()) + + +async def pipeline_index_texts(rag: LightRAG, texts: List[str]): + """Index a list of texts + + Args: + rag: LightRAG instance + texts: The texts to index + """ + if not texts: + return + await rag.apipeline_enqueue_documents(texts) + await rag.apipeline_process_enqueue_documents() + + +async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: + """Save the uploaded file to a temporary location + + Args: + file: The uploaded file + + Returns: + Path: The path to the saved file + """ + # Generate unique filename to avoid conflicts + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" + + # Create a temporary file to save the uploaded content + temp_path = input_dir / "temp" / unique_filename + temp_path.parent.mkdir(exist_ok=True) + + # Save the file + with open(temp_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + return temp_path + + +async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): + """Background task to scan and index documents""" + try: + new_files = doc_manager.scan_directory_for_new_files() + scan_progress["total_files"] = len(new_files) + + logging.info(f"Found {len(new_files)} new files to index.") + for file_path in new_files: + try: + async with progress_lock: + scan_progress["current_file"] = os.path.basename(file_path) + + await pipeline_index_file(rag, file_path) + + async with progress_lock: + scan_progress["indexed_count"] += 1 + scan_progress["progress"] = ( + scan_progress["indexed_count"] / scan_progress["total_files"] + ) * 100 + + except Exception as e: + logging.error(f"Error indexing file {file_path}: {str(e)}") + + except Exception as e: + logging.error(f"Error during scanning process: {str(e)}") + finally: + async with progress_lock: + scan_progress["is_scanning"] = False + + +def create_document_routes( + rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None +): + optional_api_key = get_api_key_dependency(api_key) + + @router.post("/scan", 
dependencies=[Depends(optional_api_key)]) + async def scan_for_new_documents(background_tasks: BackgroundTasks): + """ + Trigger the scanning process for new documents. + + This endpoint initiates a background task that scans the input directory for new documents + and processes them. If a scanning process is already running, it returns a status indicating + that fact. + + Returns: + dict: A dictionary containing the scanning status + """ + async with progress_lock: + if scan_progress["is_scanning"]: + return {"status": "already_scanning"} + + scan_progress["is_scanning"] = True + scan_progress["indexed_count"] = 0 + scan_progress["progress"] = 0 + + # Start the scanning process in the background + background_tasks.add_task(run_scanning_process, rag, doc_manager) + return {"status": "scanning_started"} + + @router.get("/scan-progress") + async def get_scan_progress(): + """ + Get the current progress of the document scanning process. + + Returns: + dict: A dictionary containing the current scanning progress information including: + - is_scanning: Whether a scan is currently in progress + - current_file: The file currently being processed + - indexed_count: Number of files indexed so far + - total_files: Total number of files to process + - progress: Percentage of completion + """ + async with progress_lock: + return scan_progress + + @router.post("/upload", dependencies=[Depends(optional_api_key)]) + async def upload_to_input_dir( + background_tasks: BackgroundTasks, file: UploadFile = File(...) + ): + """ + Upload a file to the input directory and index it. + + This API endpoint accepts a file through an HTTP POST request, checks if the + uploaded file is of a supported type, saves it in the specified input directory, + indexes it for retrieval, and returns a success status with relevant details. + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + file (UploadFile): The file to be uploaded. It must have an allowed extension. + + Returns: + InsertResponse: A response object containing the upload status and a message. + + Raises: + HTTPException: If the file type is not supported (400) or other errors occur (500). + """ + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + file_path = doc_manager.input_dir / file.filename + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + # Add to background tasks + background_tasks.add_task(pipeline_index_file, rag, file_path) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post( + "/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + ) + async def insert_text( + request: InsertTextRequest, background_tasks: BackgroundTasks + ): + """ + Insert text into the RAG system. + + This endpoint allows you to insert text data into the RAG system for later retrieval + and use in generating responses. + + Args: + request (InsertTextRequest): The request body containing the text to be inserted. 
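Together, `/scan` and `/scan-progress` give a fire-and-forget trigger plus a pollable status object. A hedged client sketch (it assumes the router is mounted under a `/documents` prefix on the default port 9621 and that `requests` is installed; the API key header is only needed when the server was started with `--key`):

```python
import time

import requests

BASE = "http://localhost:9621/documents"   # prefix and port are assumptions
HEADERS = {"X-API-Key": "my-secret-key"}   # only if --key is configured

print(requests.post(f"{BASE}/scan", headers=HEADERS, timeout=10).json())
while True:
    progress = requests.get(f"{BASE}/scan-progress", headers=HEADERS, timeout=10).json()
    print(f"{progress['current_file']}: "
          f"{progress['indexed_count']}/{progress['total_files']} "
          f"({progress['progress']:.0f}%)")
    if not progress["is_scanning"]:
        break
    time.sleep(2)
```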
+ background_tasks: FastAPI BackgroundTasks for async processing + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If an error occurs during text processing (500). + """ + try: + background_tasks.add_task(pipeline_index_texts, rag, [request.text]) + return InsertResponse( + status="success", + message="Text successfully received. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/text: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post( + "/texts", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_texts( + request: InsertTextsRequest, background_tasks: BackgroundTasks + ): + """ + Insert multiple texts into the RAG system. + + This endpoint allows you to insert multiple text entries into the RAG system + in a single request. + + Args: + request (InsertTextsRequest): The request body containing the list of texts. + background_tasks: FastAPI BackgroundTasks for async processing + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If an error occurs during text processing (500). + """ + try: + background_tasks.add_task(pipeline_index_texts, rag, request.texts) + return InsertResponse( + status="success", + message="Text successfully received. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/text: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post( + "/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + ) + async def insert_file( + background_tasks: BackgroundTasks, file: UploadFile = File(...) + ): + """ + Insert a file directly into the RAG system. + + This endpoint accepts a file upload and processes it for inclusion in the RAG system. + The file is saved temporarily and processed in the background. + + Args: + background_tasks: FastAPI BackgroundTasks for async processing + file (UploadFile): The file to be processed + + Returns: + InsertResponse: A response object containing the status of the operation. + + Raises: + HTTPException: If the file type is not supported (400) or other errors occur (500). + """ + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + temp_path = await save_temp_file(doc_manager.input_dir, file) + + # Add to background tasks + background_tasks.add_task(pipeline_index_file, rag, temp_path) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' saved successfully. Processing will continue in background.", + ) + except Exception as e: + logging.error(f"Error /documents/file: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post( + "/file_batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_batch( + background_tasks: BackgroundTasks, files: List[UploadFile] = File(...) + ): + """ + Process multiple files in batch mode. + + This endpoint allows uploading and processing multiple files simultaneously. + It handles partial successes and provides detailed feedback about failed files. 
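Both text endpoints acknowledge receipt immediately and leave indexing to a background task, so a "success" response means accepted, not indexed. A usage sketch for `/text` (the `/documents` prefix and port are assumptions):

```python
import requests

resp = requests.post(
    "http://localhost:9621/documents/text",  # prefix and port assumed
    json={"text": "LightRAG is a lightweight RAG server."},
    timeout=30,
)
print(resp.json())
# {'status': 'success', 'message': 'Text successfully received. ...'}
```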
+ + Args: + background_tasks: FastAPI BackgroundTasks for async processing + files (List[UploadFile]): List of files to process + + Returns: + InsertResponse: A response object containing: + - status: "success", "partial_success", or "failure" + - message: Detailed information about the operation results + + Raises: + HTTPException: If an error occurs during processing (500). + """ + try: + inserted_count = 0 + failed_files = [] + temp_files = [] + + for file in files: + if doc_manager.is_supported_file(file.filename): + # Create a temporary file to save the uploaded content + temp_files.append(await save_temp_file(doc_manager.input_dir, file)) + inserted_count += 1 + else: + failed_files.append(f"{file.filename} (unsupported type)") + + if temp_files: + background_tasks.add_task(pipeline_index_files, rag, temp_files) + + # Prepare status message + if inserted_count == len(files): + status = "success" + status_message = f"Successfully inserted all {inserted_count} documents" + elif inserted_count > 0: + status = "partial_success" + status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + else: + status = "failure" + status_message = "No documents were successfully inserted" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + + return InsertResponse(status=status, message=status_message) + except Exception as e: + logging.error(f"Error /documents/batch: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.delete( + "", response_model=InsertResponse, dependencies=[Depends(optional_api_key)] + ) + async def clear_documents(): + """ + Clear all documents from the RAG system. + + This endpoint deletes all text chunks, entities vector database, and relationships + vector database, effectively clearing all documents from the RAG system. + + Returns: + InsertResponse: A response object containing the status and message. + + Raises: + HTTPException: If an error occurs during the clearing process (500). + """ + try: + rag.text_chunks = [] + rag.entities_vdb = None + rag.relationships_vdb = None + return InsertResponse( + status="success", message="All documents cleared successfully" + ) + except Exception as e: + logging.error(f"Error DELETE /documents: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("", dependencies=[Depends(optional_api_key)]) + async def documents() -> DocsStatusesResponse: + """ + Get the status of all documents in the system. + + This endpoint retrieves the current status of all documents, grouped by their + processing status (PENDING, PROCESSING, PROCESSED, FAILED). + + Returns: + DocsStatusesResponse: A response object containing a dictionary where keys are + DocStatus values and values are lists of DocStatusResponse + objects representing documents in each status category. + + Raises: + HTTPException: If an error occurs while retrieving document statuses (500). 
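`insert_batch` distinguishes `success`, `partial_success`, and `failure` based on how many uploads carry a supported extension. A hedged multi-file sketch (prefix and port assumed; `requests` encodes repeated `files` fields as one multipart request):

```python
import requests

paths = ["notes.md", "report.pdf", "archive.zip"]  # .zip is not supported
resp = requests.post(
    "http://localhost:9621/documents/file_batch",  # prefix and port assumed
    files=[("files", (p, open(p, "rb"))) for p in paths],
    timeout=120,
)
print(resp.json())
# e.g. {'status': 'partial_success',
#       'message': '... Failed files: archive.zip (unsupported type)'}
```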
+ """ + try: + statuses = ( + DocStatus.PENDING, + DocStatus.PROCESSING, + DocStatus.PROCESSED, + DocStatus.FAILED, + ) + + tasks = [rag.get_docs_by_status(status) for status in statuses] + results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks) + + response = DocsStatusesResponse() + + for idx, result in enumerate(results): + status = statuses[idx] + for doc_id, doc_status in result.items(): + if status not in response.statuses: + response.statuses[status] = [] + response.statuses[status].append( + DocStatusResponse( + id=doc_id, + content_summary=doc_status.content_summary, + content_length=doc_status.content_length, + status=doc_status.status, + created_at=DocStatusResponse.format_datetime( + doc_status.created_at + ), + updated_at=DocStatusResponse.format_datetime( + doc_status.updated_at + ), + chunks_count=doc_status.chunks_count, + error=doc_status.error, + metadata=doc_status.metadata, + ) + ) + return response + except Exception as e: + logging.error(f"Error GET /documents: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + return router diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py new file mode 100644 index 00000000..bfdb838c --- /dev/null +++ b/lightrag/api/routers/graph_routes.py @@ -0,0 +1,27 @@ +""" +This module contains all graph-related routes for the LightRAG API. +""" + +from typing import Optional + +from fastapi import APIRouter, Depends + +from ..utils_api import get_api_key_dependency + +router = APIRouter(tags=["graph"]) + + +def create_graph_routes(rag, api_key: Optional[str] = None): + optional_api_key = get_api_key_dependency(api_key) + + @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)]) + async def get_graph_labels(): + """Get all graph labels""" + return await rag.get_graph_labels() + + @router.get("/graphs", dependencies=[Depends(optional_api_key)]) + async def get_knowledge_graph(label: str): + """Get knowledge graph for a specific label""" + return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + + return router diff --git a/lightrag/api/ollama_api.py b/lightrag/api/routers/ollama_api.py similarity index 97% rename from lightrag/api/ollama_api.py rename to lightrag/api/routers/ollama_api.py index 7d9fe3b9..9688d073 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -5,31 +5,13 @@ import logging import time import json import re -import os from enum import Enum from fastapi.responses import StreamingResponse import asyncio from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam from lightrag.utils import encode_string_by_tiktoken -from dotenv import load_dotenv - - -# Load environment variables -load_dotenv(override=True) - - -class OllamaServerInfos: - # Constants for emulated Ollama model information - LIGHTRAG_NAME = "lightrag" - LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") - LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" - LIGHTRAG_SIZE = 7365960935 # it's a dummy value - LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" - LIGHTRAG_DIGEST = "sha256:lightrag" - - -ollama_server_infos = OllamaServerInfos() +from ..utils_api import ollama_server_infos # query mode according to query prefix (bypass is not LightRAG quer mode) @@ -144,7 +126,7 @@ class OllamaAPI: self.rag = rag self.ollama_server_infos = ollama_server_infos self.top_k = top_k - self.router = APIRouter() + self.router = APIRouter(tags=["ollama"]) 
self.setup_routes() def setup_routes(self): diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py new file mode 100644 index 00000000..b86c170e --- /dev/null +++ b/lightrag/api/routers/query_routes.py @@ -0,0 +1,229 @@ +""" +This module contains all query-related routes for the LightRAG API. +""" + +import json +import logging +from typing import Any, Dict, List, Literal, Optional + +from fastapi import APIRouter, Depends, HTTPException +from lightrag.base import QueryParam +from ..utils_api import get_api_key_dependency +from pydantic import BaseModel, Field, field_validator + +from ascii_colors import trace_exception + +router = APIRouter(tags=["query"]) + + +class QueryRequest(BaseModel): + query: str = Field( + min_length=1, + description="The query text", + ) + + mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field( + default="hybrid", + description="Query mode", + ) + + only_need_context: Optional[bool] = Field( + default=None, + description="If True, only returns the retrieved context without generating a response.", + ) + + only_need_prompt: Optional[bool] = Field( + default=None, + description="If True, only returns the generated prompt without producing a response.", + ) + + response_type: Optional[str] = Field( + min_length=1, + default=None, + description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", + ) + + top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", + ) + + max_token_for_text_unit: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allowed for each retrieved text chunk.", + ) + + max_token_for_global_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", + ) + + max_token_for_local_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for entity descriptions in local retrieval.", + ) + + hl_keywords: Optional[List[str]] = Field( + default=None, + description="List of high-level keywords to prioritize in retrieval.", + ) + + ll_keywords: Optional[List[str]] = Field( + default=None, + description="List of low-level keywords to refine retrieval focus.", + ) + + conversation_history: Optional[List[Dict[str, Any]]] = Field( + default=None, + description="Stores past conversation history to maintain context. 
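The fields of `QueryRequest` map one-to-one onto the JSON body a client posts to `/query`. An illustrative payload (values are made up):

```python
payload = {
    "query": "How does LightRAG index documents?",
    "mode": "hybrid",          # one of: local, global, hybrid, naive, mix
    "response_type": "Multiple Paragraphs",
    "top_k": 10,
    "conversation_history": [
        {"role": "user", "content": "What is LightRAG?"},
        {"role": "assistant", "content": "A lightweight RAG framework."},
    ],
    "history_turns": 1,
}
```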
Format: [{'role': 'user/assistant', 'content': 'message'}].", + ) + + history_turns: Optional[int] = Field( + ge=0, + default=None, + description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", + ) + + @field_validator("query", mode="after") + @classmethod + def query_strip_after(cls, query: str) -> str: + return query.strip() + + @field_validator("hl_keywords", mode="after") + @classmethod + def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None: + if hl_keywords is None: + return None + return [keyword.strip() for keyword in hl_keywords] + + @field_validator("ll_keywords", mode="after") + @classmethod + def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None: + if ll_keywords is None: + return None + return [keyword.strip() for keyword in ll_keywords] + + @field_validator("conversation_history", mode="after") + @classmethod + def conversation_history_role_check( + cls, conversation_history: List[Dict[str, Any]] | None + ) -> List[Dict[str, Any]] | None: + if conversation_history is None: + return None + for msg in conversation_history: + if "role" not in msg or msg["role"] not in {"user", "assistant"}: + raise ValueError( + "Each message must have a 'role' key with value 'user' or 'assistant'." + ) + return conversation_history + + def to_query_params(self, is_stream: bool) -> "QueryParam": + """Converts a QueryRequest instance into a QueryParam instance.""" + # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically + request_data = self.model_dump(exclude_none=True, exclude={"query"}) + + # Ensure `mode` and `stream` are set explicitly + param = QueryParam(**request_data) + param.stream = is_stream + return param + + +class QueryResponse(BaseModel): + response: str = Field( + description="The generated response", + ) + + +def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): + optional_api_key = get_api_key_dependency(api_key) + + @router.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) + async def query_text(request: QueryRequest): + """ + Handle a POST request at the /query endpoint to process user queries using RAG capabilities. + + Parameters: + request (QueryRequest): The request object containing the query parameters. + Returns: + QueryResponse: A Pydantic model containing the result of the query processing. + If a string is returned (e.g., cache hit), it's directly returned. + Otherwise, an async generator may be used to build the response. + + Raises: + HTTPException: Raised when an error occurs during the request handling process, + with status code 500 and detail containing the exception message. + """ + try: + param = request.to_query_params(False) + if param.top_k is None: + param.top_k = top_k + response = await rag.aquery(request.query, param=param) + + # If response is a string (e.g. 
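`to_query_params` leans on `model_dump(exclude_none=True)`: optional fields the client never sent are dropped, so `QueryParam` keeps its own defaults for them, while explicit values pass through. A small sketch of the round-trip:

```python
req = QueryRequest(query="  What is LightRAG?  ", top_k=10)
param = req.to_query_params(is_stream=False)

# The validator stripped the query; unset optionals (response_type,
# hl_keywords, ...) never reached QueryParam, so its defaults survive.
assert req.query == "What is LightRAG?"
assert param.stream is False
assert param.top_k == 10
```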
cache hit), return directly + if isinstance(response, str): + return QueryResponse(response=response) + + if isinstance(response, dict): + result = json.dumps(response, indent=2) + return QueryResponse(response=result) + else: + return QueryResponse(response=str(response)) + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/query/stream", dependencies=[Depends(optional_api_key)]) + async def query_text_stream(request: QueryRequest): + """ + This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. + + Args: + request (QueryRequest): The request object containing the query parameters. + optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None. + + Returns: + StreamingResponse: A streaming response containing the RAG query results. + """ + try: + param = request.to_query_params(True) + if param.top_k is None: + param.top_k = top_k + response = await rag.aquery(request.query, param=param) + + from fastapi.responses import StreamingResponse + + async def stream_generator(): + if isinstance(response, str): + # If it's a string, send it all at once + yield f"{json.dumps({'response': response})}\n" + else: + # If it's an async generator, send chunks one by one + try: + async for chunk in response: + if chunk: # Only send non-empty content + yield f"{json.dumps({'response': chunk})}\n" + except Exception as e: + logging.error(f"Streaming error: {str(e)}") + yield f"{json.dumps({'error': str(e)})}\n" + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx + }, + ) + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + return router diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py new file mode 100644 index 00000000..a24e731e --- /dev/null +++ b/lightrag/api/utils_api.py @@ -0,0 +1,554 @@ +""" +Utility functions for the LightRAG API. +""" + +import os +import argparse +from typing import Optional +import sys +from ascii_colors import ASCIIColors +from lightrag.api import __api_version__ +from fastapi import HTTPException, Security +from dotenv import load_dotenv +from fastapi.security import APIKeyHeader +from starlette.status import HTTP_403_FORBIDDEN + +# Load environment variables +load_dotenv(override=True) + + +class OllamaServerInfos: + # Constants for emulated Ollama model information + LIGHTRAG_NAME = "lightrag" + LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") + LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" + LIGHTRAG_SIZE = 7365960935 # it's a dummy value + LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" + LIGHTRAG_DIGEST = "sha256:lightrag" + + +ollama_server_infos = OllamaServerInfos() + + +def get_api_key_dependency(api_key: Optional[str]): + """ + Create an API key dependency for route protection. + + Args: + api_key (Optional[str]): The API key to validate against. + If None, no authentication is required. + + Returns: + Callable: A dependency function that validates the API key. 
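Because `/query/stream` emits newline-delimited JSON, a client can render chunks as they arrive instead of buffering the full answer. A hedged consumer sketch (port assumed; any HTTP client with line streaming works):

```python
import json

import requests

with requests.post(
    "http://localhost:9621/query/stream",   # port assumed
    json={"query": "Summarize the corpus", "mode": "hybrid"},
    stream=True,
    timeout=300,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        # Each NDJSON line carries either a 'response' fragment or an 'error'.
        print(chunk.get("response", chunk.get("error", "")), end="", flush=True)
```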
+ """ + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth( + api_key_header_value: Optional[str] = Security(api_key_header), + ): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="API Key required" + ) + if api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth + + +class DefaultRAGStorageConfig: + KV_STORAGE = "JsonKVStorage" + VECTOR_STORAGE = "NanoVectorDBStorage" + GRAPH_STORAGE = "NetworkXStorage" + DOC_STATUS_STORAGE = "JsonDocStatusStorage" + + +def get_default_host(binding_type: str) -> str: + default_hosts = { + "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), + "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"), + "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"), + "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"), + } + return default_hosts.get( + binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434") + ) # fallback to ollama if unknown + + +def get_env_value(env_key: str, default: any, value_type: type = str) -> any: + """ + Get value from environment variable with type conversion + + Args: + env_key (str): Environment variable key + default (any): Default value if env variable is not set + value_type (type): Type to convert the value to + + Returns: + any: Converted value from environment or default + """ + value = os.getenv(env_key) + if value is None: + return default + + if value_type is bool: + return value.lower() in ("true", "1", "yes", "t", "on") + try: + return value_type(value) + except ValueError: + return default + + +def parse_args() -> argparse.Namespace: + """ + Parse command line arguments with environment variable fallback + + Returns: + argparse.Namespace: Parsed arguments + """ + + parser = argparse.ArgumentParser( + description="LightRAG FastAPI Server with separate working and input directories" + ) + + parser.add_argument( + "--kv-storage", + default=get_env_value( + "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE + ), + help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})", + ) + parser.add_argument( + "--doc-status-storage", + default=get_env_value( + "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE + ), + help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", + ) + parser.add_argument( + "--graph-storage", + default=get_env_value( + "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE + ), + help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", + ) + parser.add_argument( + "--vector-storage", + default=get_env_value( + "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE + ), + help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", + ) + + # Bindings configuration + parser.add_argument( + "--llm-binding", + default=get_env_value("LLM_BINDING", "ollama"), + help="LLM binding to be used. 
Supported: lollms, ollama, openai (default: from env or ollama)",
+    )
+    parser.add_argument(
+        "--embedding-binding",
+        default=get_env_value("EMBEDDING_BINDING", "ollama"),
+        help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
+    )
+
+    # Server configuration
+    parser.add_argument(
+        "--host",
+        default=get_env_value("HOST", "0.0.0.0"),
+        help="Server host (default: from env or 0.0.0.0)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=get_env_value("PORT", 9621, int),
+        help="Server port (default: from env or 9621)",
+    )
+
+    # Directory configuration
+    parser.add_argument(
+        "--working-dir",
+        default=get_env_value("WORKING_DIR", "./rag_storage"),
+        help="Working directory for RAG storage (default: from env or ./rag_storage)",
+    )
+    parser.add_argument(
+        "--input-dir",
+        default=get_env_value("INPUT_DIR", "./inputs"),
+        help="Directory containing input documents (default: from env or ./inputs)",
+    )
+
+    # LLM Model configuration
+    parser.add_argument(
+        "--llm-binding-host",
+        default=get_env_value("LLM_BINDING_HOST", None),
+        help="LLM server host URL. If not provided, defaults based on llm-binding:\n"
+        + "- ollama: http://localhost:11434\n"
+        + "- lollms: http://localhost:9600\n"
+        + "- openai: https://api.openai.com/v1",
+    )
+
+    default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None)
+
+    parser.add_argument(
+        "--llm-binding-api-key",
+        default=default_llm_api_key,
+        help="LLM server API key (default: from env or None)",
+    )
+
+    parser.add_argument(
+        "--llm-model",
+        default=get_env_value("LLM_MODEL", "mistral-nemo:latest"),
+        help="LLM model name (default: from env or mistral-nemo:latest)",
+    )
+
+    # Embedding model configuration
+    parser.add_argument(
+        "--embedding-binding-host",
+        default=get_env_value("EMBEDDING_BINDING_HOST", None),
+        help="Embedding server host URL. If not provided, defaults based on embedding-binding:\n"
+        + "- ollama: http://localhost:11434\n"
+        + "- lollms: http://localhost:9600\n"
+        + "- openai: https://api.openai.com/v1",
+    )
+
+    default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
+    parser.add_argument(
+        "--embedding-binding-api-key",
+        default=default_embedding_api_key,
+        help="Embedding server API key (default: from env or empty string)",
+    )
+
+    parser.add_argument(
+        "--embedding-model",
+        default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"),
+        help="Embedding model name (default: from env or bge-m3:latest)",
+    )
+
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        default=get_env_value("CHUNK_SIZE", 1200, int),
+        help="Chunk size (default: from env or 1200)",
+    )
+
+    parser.add_argument(
+        "--chunk_overlap_size",
+        type=int,
+        default=get_env_value("CHUNK_OVERLAP_SIZE", 100, int),
+        help="Chunk overlap size (default: from env or 100)",
+    )
+
+    def timeout_type(value):
+        if value is None or value == "None":
+            return None
+        return int(value)
+
+    parser.add_argument(
+        "--timeout",
+        default=get_env_value("TIMEOUT", None, timeout_type),
+        type=timeout_type,
+        help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
+    )
+
+    # RAG configuration
+    parser.add_argument(
+        "--max-async",
+        type=int,
+        default=get_env_value("MAX_ASYNC", 4, int),
+        help="Maximum async operations (default: from env or 4)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=get_env_value("MAX_TOKENS", 32768, int),
+        help="Maximum token size (default: from env or 32768)",
+    )
+    parser.add_argument(
+        "--embedding-dim",
+        type=int,
+        default=get_env_value("EMBEDDING_DIM", 1024, int),
+        help="Embedding dimensions (default: from env or 1024)",
+    )
+    parser.add_argument(
+        "--max-embed-tokens",
+        type=int,
+        default=get_env_value("MAX_EMBED_TOKENS", 8192, int),
+        help="Maximum embedding token size (default: from env or 8192)",
+    )
+
+    # Logging configuration
+    parser.add_argument(
+        "--log-level",
+        default=get_env_value("LOG_LEVEL", "INFO"),
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        help="Logging level (default: from env or INFO)",
+    )
+
+    parser.add_argument(
+        "--key",
+        type=str,
+        default=get_env_value("LIGHTRAG_API_KEY", None),
+        help="API key for authentication. This protects the lightrag server against unauthorized access",
+    )
+
+    # Optional https parameters
+    parser.add_argument(
+        "--ssl",
+        action="store_true",
+        default=get_env_value("SSL", False, bool),
+        help="Enable HTTPS (default: from env or False)",
+    )
+    parser.add_argument(
+        "--ssl-certfile",
+        default=get_env_value("SSL_CERTFILE", None),
+        help="Path to SSL certificate file (required if --ssl is enabled)",
+    )
+    parser.add_argument(
+        "--ssl-keyfile",
+        default=get_env_value("SSL_KEYFILE", None),
+        help="Path to SSL private key file (required if --ssl is enabled)",
+    )
+    parser.add_argument(
+        "--auto-scan-at-startup",
+        action="store_true",
+        default=False,
+        help="Enable automatic scanning when the program starts",
+    )
+
+    parser.add_argument(
+        "--history-turns",
+        type=int,
+        default=get_env_value("HISTORY_TURNS", 3, int),
+        help="Number of conversation history turns to include (default: from env or 3)",
+    )
+
+    # Search parameters
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=get_env_value("TOP_K", 60, int),
+        help="Number of most similar results to return (default: from env or 60)",
+    )
+    parser.add_argument(
+        "--cosine-threshold",
+        type=float,
+        default=get_env_value("COSINE_THRESHOLD", 0.2, float),
+        help="Cosine similarity threshold (default: from env or 0.2)",
+    )
+
+    # Ollama model name
+    parser.add_argument(
+        "--simulated-model-name",
+        type=str,
+        default=get_env_value(
+            "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
+        ),
+        help="Name of the simulated Ollama model (default: from env or lightrag:latest)",
+    )
+
+    # Namespace
+    parser.add_argument(
+        "--namespace-prefix",
+        type=str,
+        default=get_env_value("NAMESPACE_PREFIX", ""),
+        help="Prefix of the namespace",
+    )
+
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        default=get_env_value("VERBOSE", False, bool),
+        help="Enable verbose debug output (default: from env or false)",
+    )
+
+    args = parser.parse_args()
+
+    # convert relative path to absolute path
+    args.working_dir = os.path.abspath(args.working_dir)
+    args.input_dir = os.path.abspath(args.input_dir)
+
+    ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
+
+    return args
+
+
+def display_splash_screen(args: argparse.Namespace) -> None:
+    """
+    Display a colorful splash screen showing LightRAG server configuration
+
+    Args:
+        args: Parsed command line arguments
+    """
+    # Banner
+    ASCIIColors.cyan(f"""
╔══════════════════════════════════════════════════════════════╗ + ║ 🚀 LightRAG Server v{__api_version__} ║ + ║ Fast, Lightweight RAG Server Implementation ║ + ╚══════════════════════════════════════════════════════════════╝ + """) + + # Server Configuration + ASCIIColors.magenta("\n📡 Server Configuration:") + ASCIIColors.white(" ├─ Host: ", end="") + ASCIIColors.yellow(f"{args.host}") + ASCIIColors.white(" ├─ Port: ", end="") + ASCIIColors.yellow(f"{args.port}") + ASCIIColors.white(" ├─ CORS Origins: ", end="") + ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") + ASCIIColors.white(" ├─ SSL Enabled: ", end="") + ASCIIColors.yellow(f"{args.ssl}") + ASCIIColors.white(" └─ API Key: ", end="") + ASCIIColors.yellow("Set" if args.key else "Not Set") + if args.ssl: + ASCIIColors.white(" ├─ SSL Cert: ", end="") + ASCIIColors.yellow(f"{args.ssl_certfile}") + ASCIIColors.white(" └─ SSL Key: ", end="") + ASCIIColors.yellow(f"{args.ssl_keyfile}") + + # Directory Configuration + ASCIIColors.magenta("\n📂 Directory Configuration:") + ASCIIColors.white(" ├─ Working Directory: ", end="") + ASCIIColors.yellow(f"{args.working_dir}") + ASCIIColors.white(" └─ Input Directory: ", end="") + ASCIIColors.yellow(f"{args.input_dir}") + + # LLM Configuration + ASCIIColors.magenta("\n🤖 LLM Configuration:") + ASCIIColors.white(" ├─ Binding: ", end="") + ASCIIColors.yellow(f"{args.llm_binding}") + ASCIIColors.white(" ├─ Host: ", end="") + ASCIIColors.yellow(f"{args.llm_binding_host}") + ASCIIColors.white(" └─ Model: ", end="") + ASCIIColors.yellow(f"{args.llm_model}") + + # Embedding Configuration + ASCIIColors.magenta("\n📊 Embedding Configuration:") + ASCIIColors.white(" ├─ Binding: ", end="") + ASCIIColors.yellow(f"{args.embedding_binding}") + ASCIIColors.white(" ├─ Host: ", end="") + ASCIIColors.yellow(f"{args.embedding_binding_host}") + ASCIIColors.white(" ├─ Model: ", end="") + ASCIIColors.yellow(f"{args.embedding_model}") + ASCIIColors.white(" └─ Dimensions: ", end="") + ASCIIColors.yellow(f"{args.embedding_dim}") + + # RAG Configuration + ASCIIColors.magenta("\n⚙️ RAG Configuration:") + ASCIIColors.white(" ├─ Max Async Operations: ", end="") + ASCIIColors.yellow(f"{args.max_async}") + ASCIIColors.white(" ├─ Max Tokens: ", end="") + ASCIIColors.yellow(f"{args.max_tokens}") + ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") + ASCIIColors.yellow(f"{args.max_embed_tokens}") + ASCIIColors.white(" ├─ Chunk Size: ", end="") + ASCIIColors.yellow(f"{args.chunk_size}") + ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="") + ASCIIColors.yellow(f"{args.chunk_overlap_size}") + ASCIIColors.white(" ├─ History Turns: ", end="") + ASCIIColors.yellow(f"{args.history_turns}") + ASCIIColors.white(" ├─ Cosine Threshold: ", end="") + ASCIIColors.yellow(f"{args.cosine_threshold}") + ASCIIColors.white(" └─ Top-K: ", end="") + ASCIIColors.yellow(f"{args.top_k}") + + # System Configuration + ASCIIColors.magenta("\n💾 Storage Configuration:") + ASCIIColors.white(" ├─ KV Storage: ", end="") + ASCIIColors.yellow(f"{args.kv_storage}") + ASCIIColors.white(" ├─ Vector Storage: ", end="") + ASCIIColors.yellow(f"{args.vector_storage}") + ASCIIColors.white(" ├─ Graph Storage: ", end="") + ASCIIColors.yellow(f"{args.graph_storage}") + ASCIIColors.white(" └─ Document Status Storage: ", end="") + ASCIIColors.yellow(f"{args.doc_status_storage}") + + ASCIIColors.magenta("\n🛠️ System Configuration:") + ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") + ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") + ASCIIColors.white(" 
├─ Log Level: ", end="")
+    ASCIIColors.yellow(f"{args.log_level}")
+    ASCIIColors.white("    ├─ Verbose Debug: ", end="")
+    ASCIIColors.yellow(f"{args.verbose}")
+    ASCIIColors.white("    └─ Timeout: ", end="")
+    ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
+
+    # Server Status
+    ASCIIColors.green("\n✨ Server starting up...\n")
+
+    # Server Access Information
+    protocol = "https" if args.ssl else "http"
+    if args.host == "0.0.0.0":
+        ASCIIColors.magenta("\n🌐 Server Access Information:")
+        ASCIIColors.white("    ├─ Local Access: ", end="")
+        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
+        ASCIIColors.white("    ├─ Remote Access: ", end="")
+        ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
+        ASCIIColors.white("    ├─ API Documentation (local): ", end="")
+        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
+        ASCIIColors.white("    ├─ Alternative Documentation (local): ", end="")
+        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
+        ASCIIColors.white("    └─ WebUI (local): ", end="")
+        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
+
+        ASCIIColors.yellow("\n📝 Note:")
+        ASCIIColors.white("""    Since the server is running on 0.0.0.0:
+    - Use 'localhost' or '127.0.0.1' for local access
+    - Use your machine's IP address for remote access
+    - To find your IP address:
+      • Windows: Run 'ipconfig' in terminal
+      • Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal
+    """)
+    else:
+        base_url = f"{protocol}://{args.host}:{args.port}"
+        ASCIIColors.magenta("\n🌐 Server Access Information:")
+        ASCIIColors.white("    ├─ Base URL: ", end="")
+        ASCIIColors.yellow(f"{base_url}")
+        ASCIIColors.white("    ├─ API Documentation: ", end="")
+        ASCIIColors.yellow(f"{base_url}/docs")
+        ASCIIColors.white("    └─ Alternative Documentation: ", end="")
+        ASCIIColors.yellow(f"{base_url}/redoc")
+
+    # Usage Examples
+    ASCIIColors.magenta("\n📚 Quick Start Guide:")
+    ASCIIColors.cyan("""
+    1. Access the Swagger UI:
+       Open your browser and navigate to the API documentation URL above
+
+    2. API Authentication:""")
+    if args.key:
+        ASCIIColors.cyan("""    Add the following header to your requests:
+    X-API-Key: <your-api-key>
+    """)
+    else:
+        ASCIIColors.cyan("    No authentication required\n")
+
+    ASCIIColors.cyan("""    3. Basic Operations:
+    - POST /documents/upload: Upload new documents to RAG
+    - POST /query: Query your document collection
+    - GET /documents: Check the status of indexed documents
+
+    4. Monitor the server:
+    - Check server logs for detailed operation information
+    - Use healthcheck endpoint: GET /health
+    """)
+
+    # Security Notice
+    if args.key:
+        ASCIIColors.yellow("\n⚠️ Security Notice:")
+        ASCIIColors.white("""    API Key authentication is enabled.
+    Make sure to include the X-API-Key header in all your requests.
+    """)
+
+    ASCIIColors.green("Server is ready to accept connections! 
🚀\n") + + # Ensure splash output flush to system log + sys.stdout.flush() diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index b6133a4c..63a295cd 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -48,11 +48,20 @@ class JsonDocStatusStorage(DocStatusStorage): self, status: DocStatus ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == status.value - } + result = {} + for k, v in self._data.items(): + if v["status"] == status.value: + try: + # Make a copy of the data to avoid modifying the original + data = v.copy() + # If content is missing, use content_summary as content + if "content" not in data and "content_summary" in data: + data["content"] = data["content_summary"] + result[k] = DocProcessingStatus(**data) + except KeyError as e: + logger.error(f"Missing required field for document {k}: {e}") + continue + return result async def index_done_callback(self) -> None: write_json(self._data, self._file_name) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b3ad327d..12025469 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -263,9 +263,8 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): - logger.setLevel(self.log_level) os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) - set_logger(self.log_file_path) + set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") if not os.path.exists(self.working_dir): diff --git a/lightrag/utils.py b/lightrag/utils.py index d402d14c..ae7e8dce 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -57,11 +57,17 @@ logger = logging.getLogger("lightrag") logging.getLogger("httpx").setLevel(logging.WARNING) -def set_logger(log_file: str): - logger.setLevel(logging.DEBUG) +def set_logger(log_file: str, level: int = logging.DEBUG): + """Set up file logging with the specified level. + + Args: + log_file: Path to the log file + level: Logging level (e.g. logging.DEBUG, logging.INFO) + """ + logger.setLevel(level) file_handler = logging.FileHandler(log_file, encoding="utf-8") - file_handler.setLevel(logging.DEBUG) + file_handler.setLevel(level) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" diff --git a/requirements.txt b/requirements.txt index 03d93aa3..a1a1157e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,6 @@ aiohttp configparser - -# database packages -networkx +future # Basic modules numpy