Merge branch 'main' into main

This commit is contained in:
Alex Z
2025-04-05 15:27:59 -07:00
committed by GitHub
77 changed files with 5920 additions and 5192 deletions

View File

@@ -291,11 +291,9 @@ LightRAG uses 4 types of storage for different purposes:
```
JsonKVStorage JsonFile (default)
MongoKVStorage MongoDB
RedisKVStorage Redis
TiDBKVStorage TiDB
PGKVStorage Postgres
OracleKVStorage Oracle
RedisKVStorage Redis
MongoKVStorage MongoDB
```
* GRAPH_STORAGE supported implementation names
@@ -303,25 +301,19 @@ OracleKVStorage Oracle
```
NetworkXStorage NetworkX (default)
Neo4JStorage Neo4J
MongoGraphStorage MongoDB
TiDBGraphStorage TiDB
AGEStorage AGE
GremlinStorage Gremlin
PGGraphStorage Postgres
OracleGraphStorage Oracle
AGEStorage AGE
```
* VECTOR_STORAGE supported implementation names
```
NanoVectorDBStorage NanoVector (default)
PGVectorStorage Postgres
MilvusVectorDBStorge Milvus
ChromaVectorDBStorage Chroma
TiDBVectorDBStorage TiDB
PGVectorStorage Postgres
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
OracleVectorDBStorage Oracle
MongoVectorDBStorage MongoDB
```

View File

@@ -302,11 +302,9 @@ Each storage type has several implementations:
```
JsonKVStorage JsonFile (default)
MongoKVStorage MongoDB
RedisKVStorage Redis
TiDBKVStorage TiDB
PGKVStorage Postgres
OracleKVStorage Oracle
RedisKVStorage Redis
MongoKVStorage MongoDB
```
* GRAPH_STORAGE supported implementation names
@@ -314,25 +312,19 @@ OracleKVStorage Oracle
```
NetworkXStorage NetworkX (default)
Neo4JStorage Neo4J
MongoGraphStorage MongoDB
TiDBGraphStorage TiDB
AGEStorage AGE
GremlinStorage Gremlin
PGGraphStorage Postgres
OracleGraphStorage Oracle
AGEStorage AGE
```
* VECTOR_STORAGE supported implementation names
```
NanoVectorDBStorage NanoVector (default)
MilvusVectorDBStorage Milvus
ChromaVectorDBStorage Chroma
TiDBVectorDBStorage TiDB
PGVectorStorage Postgres
MilvusVectorDBStorge Milvus
ChromaVectorDBStorage Chroma
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
OracleVectorDBStorage Oracle
MongoVectorDBStorage MongoDB
```

View File

@@ -1 +1 @@
__api_version__ = "1.2.8"
__api_version__ = "0136"

View File

@@ -1,9 +1,11 @@
import os
from datetime import datetime, timedelta
import jwt
from dotenv import load_dotenv
from fastapi import HTTPException, status
from pydantic import BaseModel
from dotenv import load_dotenv
from .config import global_args
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):
class AuthHandler:
def __init__(self):
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
self.algorithm = "HS256"
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
self.guest_expire_hours = int(os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2))
self.secret = global_args.token_secret
self.algorithm = global_args.jwt_algorithm
self.expire_hours = global_args.token_expire_hours
self.guest_expire_hours = global_args.guest_token_expire_hours
self.accounts = {}
auth_accounts = os.getenv("AUTH_ACCOUNTS")
auth_accounts = global_args.auth_accounts
if auth_accounts:
for account in auth_accounts.split(","):
username, password = account.split(":", 1)

lightrag/api/config.py Normal file
View File

@@ -0,0 +1,335 @@
"""
Configs for the LightRAG API.
"""
import os
import argparse
import logging
from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
ollama_server_infos = OllamaServerInfos()
class DefaultRAGStorageConfig:
KV_STORAGE = "JsonKVStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
GRAPH_STORAGE = "NetworkXStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
}
return default_hosts.get(
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
) # fallback to ollama if unknown
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
"""
Get value from environment variable with type conversion
Args:
env_key (str): Environment variable key
default (any): Default value if env variable is not set
value_type (type): Type to convert the value to
Returns:
any: Converted value from environment or default
"""
value = os.getenv(env_key)
if value is None:
return default
if value_type is bool:
return value.lower() in ("true", "1", "yes", "t", "on")
try:
return value_type(value)
except ValueError:
return default
def parse_args() -> argparse.Namespace:
"""
Parse command line arguments with environment variable fallback
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Server configuration
parser.add_argument(
"--host",
default=get_env_value("HOST", "0.0.0.0"),
help="Server host (default: from env or 0.0.0.0)",
)
parser.add_argument(
"--port",
type=int,
default=get_env_value("PORT", 9621, int),
help="Server port (default: from env or 9621)",
)
# Directory configuration
parser.add_argument(
"--working-dir",
default=get_env_value("WORKING_DIR", "./rag_storage"),
help="Working directory for RAG storage (default: from env or ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default=get_env_value("INPUT_DIR", "./inputs"),
help="Directory containing input documents (default: from env or ./inputs)",
)
def timeout_type(value):
if value is None:
return 150
if value == "None":
return None
return int(value)
parser.add_argument(
"--timeout",
default=get_env_value("TIMEOUT", None, timeout_type),
type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
)
# RAG configuration
parser.add_argument(
"--max-async",
type=int,
default=get_env_value("MAX_ASYNC", 4, int),
help="Maximum async operations (default: from env or 4)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=get_env_value("MAX_TOKENS", 32768, int),
help="Maximum token size (default: from env or 32768)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default=get_env_value("LOG_LEVEL", "INFO"),
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: from env or INFO)",
)
parser.add_argument(
"--verbose",
action="store_true",
default=get_env_value("VERBOSE", False, bool),
help="Enable verbose debug output(only valid for DEBUG log-level)",
)
parser.add_argument(
"--key",
type=str,
default=get_env_value("LIGHTRAG_API_KEY", None),
help="API key for authentication. This protects lightrag server against unauthorized access",
)
# Optional https parameters
parser.add_argument(
"--ssl",
action="store_true",
default=get_env_value("SSL", False, bool),
help="Enable HTTPS (default: from env or False)",
)
parser.add_argument(
"--ssl-certfile",
default=get_env_value("SSL_CERTFILE", None),
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=get_env_value("SSL_KEYFILE", None),
help="Path to SSL private key file (required if --ssl is enabled)",
)
parser.add_argument(
"--history-turns",
type=int,
default=get_env_value("HISTORY_TURNS", 3, int),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Search parameters
parser.add_argument(
"--top-k",
type=int,
default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
help="Cosine similarity threshold (default: from env or 0.4)",
)
# Ollama model name
parser.add_argument(
"--simulated-model-name",
type=str,
default=get_env_value(
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Namespace
parser.add_argument(
"--namespace-prefix",
type=str,
default=get_env_value("NAMESPACE_PREFIX", ""),
help="Prefix of the namespace",
)
parser.add_argument(
"--auto-scan-at-startup",
action="store_true",
default=False,
help="Enable automatic scanning when the program starts",
)
# Server workers configuration
parser.add_argument(
"--workers",
type=int,
default=get_env_value("WORKERS", 1, int),
help="Number of worker processes (default: from env or 1)",
)
# LLM and embedding bindings
parser.add_argument(
"--llm-binding",
type=str,
default=get_env_value("LLM_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
help="LLM binding type (default: from env or ollama)",
)
parser.add_argument(
"--embedding-binding",
type=str,
default=get_env_value("EMBEDDING_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "azure_openai"],
help="Embedding binding type (default: from env or ollama)",
)
args = parser.parse_args()
# convert relative path to absolute path
args.working_dir = os.path.abspath(args.working_dir)
args.input_dir = os.path.abspath(args.input_dir)
# Inject storage configuration from environment variables
args.kv_storage = get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
)
args.doc_status_storage = get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
)
args.graph_storage = get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
)
args.vector_storage = get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
)
# Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
# Handle openai-ollama special case
if args.llm_binding == "openai-ollama":
args.llm_binding = "openai"
args.embedding_binding = "ollama"
args.llm_binding_host = get_env_value(
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
)
args.embedding_binding_host = get_env_value(
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
)
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
# Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
# Inject chunk configuration
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
# Inject LLM cache configuration
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
)
# Inject LLM temperature configuration
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
# Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
# Add environment variables that were previously read directly
args.cors_origins = get_env_value("CORS_ORIGINS", "*")
args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
# For JWT Auth
args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
return args
def update_uvicorn_mode_config():
# If in uvicorn mode and workers > 1, force it to 1 and log warning
if global_args.workers > 1:
original_workers = global_args.workers
global_args.workers = 1
# Log warning directly here
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
)
global_args = parse_args()

View File

@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_combined_auth_dependency,
parse_args,
get_default_host,
display_splash_screen,
check_env_file,
)
from .config import (
global_args,
update_uvicorn_mode_config,
get_default_host,
)
import sys
from lightrag import LightRAG, __version__ as core_version
from lightrag.api import __api_version__
@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
webui_title = os.getenv("WEBUI_TITLE")
webui_description = os.getenv("WEBUI_DESCRIPTION")
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
@@ -164,10 +171,10 @@ def create_app(args):
app = FastAPI(**app_kwargs)
def get_cors_origins():
"""Get allowed origins from environment variable
"""Get allowed origins from global_args
Returns a list of allowed origins, defaults to ["*"] if not set
"""
origins_str = os.getenv("CORS_ORIGINS", "*")
origins_str = global_args.cors_origins
if origins_str == "*":
return ["*"]
return [origin.strip() for origin in origins_str.split(",")]
@@ -315,9 +322,10 @@ def create_app(args):
"similarity_threshold": 0.95,
"use_llm_check": False,
},
namespace_prefix=args.namespace_prefix,
# namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
addon_params={"language": args.summary_language},
)
else: # azure_openai
rag = LightRAG(
@@ -345,9 +353,10 @@ def create_app(args):
"similarity_threshold": 0.95,
"use_llm_check": False,
},
namespace_prefix=args.namespace_prefix,
# namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
addon_params={"language": args.summary_language},
)
# Add routes
@@ -381,6 +390,8 @@ def create_app(args):
"message": "Authentication is disabled. Using guest access.",
"core_version": core_version,
"api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
}
return {
@@ -388,6 +399,8 @@ def create_app(args):
"auth_mode": "enabled",
"core_version": core_version,
"api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
}
@app.post("/login")
@@ -404,6 +417,8 @@ def create_app(args):
"message": "Authentication is disabled. Using guest access.",
"core_version": core_version,
"api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
}
username = form_data.username
if auth_handler.accounts.get(username) != form_data.password:
@@ -421,6 +436,8 @@ def create_app(args):
"auth_mode": "enabled",
"core_version": core_version,
"api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
}
@app.get("/health", dependencies=[Depends(combined_auth)])
@@ -454,10 +471,12 @@ def create_app(args):
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
},
"core_version": core_version,
"api_version": __api_version__,
"auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False),
"core_version": core_version,
"api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
}
except Exception as e:
logger.error(f"Error getting health status: {str(e)}")
@@ -490,7 +509,7 @@ def create_app(args):
def get_application(args=None):
"""Factory function for creating the FastAPI application"""
if args is None:
args = parse_args()
args = global_args
return create_app(args)
@@ -611,30 +630,31 @@ def main():
# Configure logging before parsing args
configure_logging()
args = parse_args(is_uvicorn_mode=True)
display_splash_screen(args)
update_uvicorn_mode_config()
display_splash_screen(global_args)
# Create application instance directly instead of using factory function
app = create_app(args)
app = create_app(global_args)
# Start Uvicorn in single process mode
uvicorn_config = {
"app": app, # Pass application instance directly instead of string path
"host": args.host,
"port": args.port,
"host": global_args.host,
"port": global_args.port,
"log_config": None, # Disable default config
}
if args.ssl:
if global_args.ssl:
uvicorn_config.update(
{
"ssl_certfile": args.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile,
"ssl_certfile": global_args.ssl_certfile,
"ssl_keyfile": global_args.ssl_keyfile,
}
)
print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
print(
f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
)
uvicorn.run(**uvicorn_config)

View File

@@ -10,16 +10,14 @@ import traceback
import pipmaster as pm
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Literal
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.api.utils_api import (
get_combined_auth_dependency,
global_args,
)
from lightrag.api.utils_api import get_combined_auth_dependency
from ..config import global_args
router = APIRouter(
prefix="/documents",
@@ -30,7 +28,37 @@ router = APIRouter(
temp_prefix = "__tmp__"
class ScanResponse(BaseModel):
"""Response model for document scanning operation
Attributes:
status: Status of the scanning operation
message: Optional message with additional details
"""
status: Literal["scanning_started"] = Field(
description="Status of the scanning operation"
)
message: Optional[str] = Field(
default=None, description="Additional details about the scanning operation"
)
class Config:
json_schema_extra = {
"example": {
"status": "scanning_started",
"message": "Scanning process has been initiated in the background",
}
}
class InsertTextRequest(BaseModel):
"""Request model for inserting a single text document
Attributes:
text: The text content to be inserted into the RAG system
"""
text: str = Field(
min_length=1,
description="The text to insert",
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
def strip_after(cls, text: str) -> str:
return text.strip()
class Config:
json_schema_extra = {
"example": {
"text": "This is a sample text to be inserted into the RAG system."
}
}
class InsertTextsRequest(BaseModel):
"""Request model for inserting multiple text documents
Attributes:
texts: List of text contents to be inserted into the RAG system
"""
texts: list[str] = Field(
min_length=1,
description="The texts to insert",
@@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel):
def strip_after(cls, texts: list[str]) -> list[str]:
return [text.strip() for text in texts]
class Config:
json_schema_extra = {
"example": {
"texts": [
"This is the first text to be inserted.",
"This is the second text to be inserted.",
]
}
}
class InsertResponse(BaseModel):
status: str = Field(description="Status of the operation")
"""Response model for document insertion operations
Attributes:
status: Status of the operation (success, duplicated, partial_success, failure)
message: Detailed message describing the operation result
"""
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
description="Status of the operation"
)
message: str = Field(description="Message describing the operation result")
class Config:
json_schema_extra = {
"example": {
"status": "success",
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
}
}
class ClearDocumentsResponse(BaseModel):
"""Response model for document clearing operation
Attributes:
status: Status of the clear operation
message: Detailed message describing the operation result
"""
status: Literal["success", "partial_success", "busy", "fail"] = Field(
description="Status of the clear operation"
)
message: str = Field(description="Message describing the operation result")
class Config:
json_schema_extra = {
"example": {
"status": "success",
"message": "All documents cleared successfully. Deleted 15 files.",
}
}
class ClearCacheRequest(BaseModel):
"""Request model for clearing cache
Attributes:
modes: Optional list of cache modes to clear
"""
modes: Optional[
List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
] = Field(
default=None,
description="Modes of cache to clear. If None, clears all cache.",
)
class Config:
json_schema_extra = {"example": {"modes": ["default", "naive"]}}
class ClearCacheResponse(BaseModel):
"""Response model for cache clearing operation
Attributes:
status: Status of the clear operation
message: Detailed message describing the operation result
"""
status: Literal["success", "fail"] = Field(
description="Status of the clear operation"
)
message: str = Field(description="Message describing the operation result")
class Config:
json_schema_extra = {
"example": {
"status": "success",
"message": "Successfully cleared cache for modes: ['default', 'naive']",
}
}
"""Response model for document status
Attributes:
id: Document identifier
content_summary: Summary of document content
content_length: Length of document content
status: Current processing status
created_at: Creation timestamp (ISO format string)
updated_at: Last update timestamp (ISO format string)
chunks_count: Number of chunks (optional)
error: Error message if any (optional)
metadata: Additional metadata (optional)
file_path: Path to the document file
"""
class DocStatusResponse(BaseModel):
@staticmethod
@@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel):
return dt
return dt.isoformat()
"""Response model for document status
id: str = Field(description="Document identifier")
content_summary: str = Field(description="Summary of document content")
content_length: int = Field(description="Length of document content in characters")
status: DocStatus = Field(description="Current processing status")
created_at: str = Field(description="Creation timestamp (ISO format string)")
updated_at: str = Field(description="Last update timestamp (ISO format string)")
chunks_count: Optional[int] = Field(
default=None, description="Number of chunks the document was split into"
)
error: Optional[str] = Field(
default=None, description="Error message if processing failed"
)
metadata: Optional[dict[str, Any]] = Field(
default=None, description="Additional metadata about the document"
)
file_path: str = Field(description="Path to the document file")
Attributes:
id: Document identifier
content_summary: Summary of document content
content_length: Length of document content
status: Current processing status
created_at: Creation timestamp (ISO format string)
updated_at: Last update timestamp (ISO format string)
chunks_count: Number of chunks (optional)
error: Error message if any (optional)
metadata: Additional metadata (optional)
"""
id: str
content_summary: str
content_length: int
status: DocStatus
created_at: str
updated_at: str
chunks_count: Optional[int] = None
error: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
file_path: str
class Config:
json_schema_extra = {
"example": {
"id": "doc_123456",
"content_summary": "Research paper on machine learning",
"content_length": 15240,
"status": "PROCESSED",
"created_at": "2025-03-31T12:34:56",
"updated_at": "2025-03-31T12:35:30",
"chunks_count": 12,
"error": None,
"metadata": {"author": "John Doe", "year": 2025},
"file_path": "research_paper.pdf",
}
}
class DocsStatusesResponse(BaseModel):
statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
"""Response model for document statuses
Attributes:
statuses: Dictionary mapping document status to lists of document status responses
"""
statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
default_factory=dict,
description="Dictionary mapping document status to lists of document status responses",
)
class Config:
json_schema_extra = {
"example": {
"statuses": {
"PENDING": [
{
"id": "doc_123",
"content_summary": "Pending document",
"content_length": 5000,
"status": "PENDING",
"created_at": "2025-03-31T10:00:00",
"updated_at": "2025-03-31T10:00:00",
"file_path": "pending_doc.pdf",
}
],
"PROCESSED": [
{
"id": "doc_456",
"content_summary": "Processed document",
"content_length": 8000,
"status": "PROCESSED",
"created_at": "2025-03-31T09:00:00",
"updated_at": "2025-03-31T09:05:00",
"chunks_count": 8,
"file_path": "processed_doc.pdf",
}
],
}
}
}
class PipelineStatusResponse(BaseModel):
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
)
return False
case ".pdf":
if global_args["main_args"].document_loading_engine == "DOCLING":
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if global_args["main_args"].document_loading_engine == "DOCLING":
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if global_args["main_args"].document_loading_engine == "DOCLING":
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
if hasattr(shape, "text"):
content += shape.text + "\n"
case ".xlsx":
if global_args["main_args"].document_loading_engine == "DOCLING":
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
await rag.apipeline_process_enqueue_documents()
# TODO: deprecate after /insert_file is removed
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
"""Save the uploaded file to a temporary location
@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
if not new_files:
return
# Get MAX_PARALLEL_INSERT from global_args["main_args"]
max_parallel = global_args["main_args"].max_parallel_insert
# Get MAX_PARALLEL_INSERT from global_args
max_parallel = global_args.max_parallel_insert
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
batch_size = 2 * max_parallel
@@ -509,7 +704,9 @@ def create_document_routes(
# Create combined auth dependency for document routes
combined_auth = get_combined_auth_dependency(api_key)
@router.post("/scan", dependencies=[Depends(combined_auth)])
@router.post(
"/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
)
async def scan_for_new_documents(background_tasks: BackgroundTasks):
"""
Trigger the scanning process for new documents.
@@ -519,13 +716,18 @@ def create_document_routes(
that fact.
Returns:
dict: A dictionary containing the scanning status
ScanResponse: A response object containing the scanning status
"""
# Start the scanning process in the background
background_tasks.add_task(run_scanning_process, rag, doc_manager)
return {"status": "scanning_started"}
return ScanResponse(
status="scanning_started",
message="Scanning process has been initiated in the background",
)
@router.post("/upload", dependencies=[Depends(combined_auth)])
@router.post(
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
@@ -645,6 +847,7 @@ def create_document_routes(
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
# TODO: deprecated, use /upload instead
@router.post(
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
@@ -688,6 +891,7 @@ def create_document_routes(
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
# TODO: deprecated, use /upload instead
@router.post(
"/file_batch",
response_model=InsertResponse,
@@ -752,32 +956,186 @@ def create_document_routes(
raise HTTPException(status_code=500, detail=str(e))
@router.delete(
"", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
"", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
)
async def clear_documents():
"""
Clear all documents from the RAG system.
This endpoint deletes all text chunks, entities vector database, and relationships
vector database, effectively clearing all documents from the RAG system.
This endpoint deletes all documents, entities, relationships, and files from the system.
It uses the storage drop methods to properly clean up all data and removes all files
from the input directory.
Returns:
InsertResponse: A response object containing the status and message.
ClearDocumentsResponse: A response object containing the status and message.
- status="success": All documents and files were successfully cleared.
- status="partial_success": Document clear job exit with some errors.
- status="busy": Operation could not be completed because the pipeline is busy.
- status="fail": All storage drop operations failed, with message
- message: Detailed information about the operation results, including counts
of deleted files and any errors encountered.
Raises:
HTTPException: If an error occurs during the clearing process (500).
HTTPException: Raised when a serious error occurs during the clearing process,
with status code 500 and error details in the detail field.
"""
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success", message="All documents cleared successfully"
from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
)
# Get pipeline status and lock
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status_lock = get_pipeline_status_lock()
# Check and set status with lock
async with pipeline_status_lock:
if pipeline_status.get("busy", False):
return ClearDocumentsResponse(
status="busy",
message="Cannot clear documents while pipeline is busy",
)
# Set busy to true
pipeline_status.update(
{
"busy": True,
"job_name": "Clearing Documents",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "Starting document clearing process",
}
)
# Clear history_messages in place to keep the shared list object intact
del pipeline_status["history_messages"][:]
pipeline_status["history_messages"].append(
"Starting document clearing process"
)
try:
# Use drop method to clear all data
drop_tasks = []
storages = [
rag.text_chunks,
rag.full_docs,
rag.entities_vdb,
rag.relationships_vdb,
rag.chunks_vdb,
rag.chunk_entity_relation_graph,
rag.doc_status,
]
# Log storage drop start
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(
"Starting to drop storage components"
)
for storage in storages:
if storage is not None:
drop_tasks.append(storage.drop())
# Wait for all drop tasks to complete
drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)
# Check for errors and log results
errors = []
storage_success_count = 0
storage_error_count = 0
for i, result in enumerate(drop_results):
storage_name = storages[i].__class__.__name__
if isinstance(result, Exception):
error_msg = f"Error dropping {storage_name}: {str(result)}"
errors.append(error_msg)
logger.error(error_msg)
storage_error_count += 1
else:
logger.info(f"Successfully dropped {storage_name}")
storage_success_count += 1
# Log storage drop results
if "history_messages" in pipeline_status:
if storage_error_count > 0:
pipeline_status["history_messages"].append(
f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
)
else:
pipeline_status["history_messages"].append(
f"Successfully dropped all {storage_success_count} storage components"
)
# If all storage operations failed, return error status and don't proceed with file deletion
if storage_success_count == 0 and storage_error_count > 0:
error_message = "All storage drop operations failed. Aborting document clearing process."
logger.error(error_message)
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(error_message)
return ClearDocumentsResponse(status="fail", message=error_message)
# Log file deletion start
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(
"Starting to delete files in input directory"
)
# Delete all files in input_dir
deleted_files_count = 0
file_errors_count = 0
for file_path in doc_manager.input_dir.glob("**/*"):
if file_path.is_file():
try:
file_path.unlink()
deleted_files_count += 1
except Exception as e:
logger.error(f"Error deleting file {file_path}: {str(e)}")
file_errors_count += 1
# Log file deletion results
if "history_messages" in pipeline_status:
if file_errors_count > 0:
pipeline_status["history_messages"].append(
f"Deleted {deleted_files_count} files with {file_errors_count} errors"
)
errors.append(f"Failed to delete {file_errors_count} files")
else:
pipeline_status["history_messages"].append(
f"Successfully deleted {deleted_files_count} files"
)
# Prepare final result message
final_message = ""
if errors:
final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
status = "partial_success"
else:
final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
status = "success"
# Log final result
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(final_message)
# Return response based on results
return ClearDocumentsResponse(status=status, message=final_message)
except Exception as e:
logger.error(f"Error DELETE /documents: {str(e)}")
error_msg = f"Error clearing documents: {str(e)}"
logger.error(error_msg)
logger.error(traceback.format_exc())
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(error_msg)
raise HTTPException(status_code=500, detail=str(e))
finally:
# Reset busy status after completion
async with pipeline_status_lock:
pipeline_status["busy"] = False
completion_msg = "Document clearing process completed"
pipeline_status["latest_message"] = completion_msg
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(completion_msg)
@router.get(
"/pipeline_status",
@@ -850,7 +1208,9 @@ def create_document_routes(
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("", dependencies=[Depends(combined_auth)])
@router.get(
"", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
)
async def documents() -> DocsStatusesResponse:
"""
Get the status of all documents in the system.
@@ -908,4 +1268,57 @@ def create_document_routes(
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/clear_cache",
response_model=ClearCacheResponse,
dependencies=[Depends(combined_auth)],
)
async def clear_cache(request: ClearCacheRequest):
"""
Clear cache data from the LLM response cache storage.
This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
- "default" represents extraction cache.
- Other modes correspond to different query modes.
Args:
request (ClearCacheRequest): The request body containing optional modes to clear.
Returns:
ClearCacheResponse: A response object containing the status and message.
Raises:
HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
"""
try:
# Validate modes if provided
valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
if request.modes and not all(mode in valid_modes for mode in request.modes):
invalid_modes = [
mode for mode in request.modes if mode not in valid_modes
]
raise HTTPException(
status_code=400,
detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
)
# Call the aclear_cache method
await rag.aclear_cache(request.modes)
# Prepare success message
if request.modes:
message = f"Successfully cleared cache for modes: {request.modes}"
else:
message = "Successfully cleared all cache"
return ClearCacheResponse(status="success", message=message)
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Error clearing cache: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
return router

View File

@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
"""
from typing import Optional
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from ..utils_api import get_combined_auth_dependency
@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graphs", dependencies=[Depends(combined_auth)])
async def get_knowledge_graph(
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
label: str = Query(..., description="Label to get knowledge graph for"),
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
):
"""
Retrieve a connected subgraph of nodes whose labels include the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
The maximum number of nodes is limited by the env var MAX_GRAPH_NODES (default: 1000).
When reducing the number of nodes, the prioritization criteria are:
1. Hops (path length) to the starting node take precedence
2. Followed by the degree of the nodes
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
label (str): Label of the starting node
max_depth (int, optional): Maximum depth of the subgraph. Defaults to 3.
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
Returns:
Dict[str, List[str]]: Knowledge graph for label
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
return await rag.get_knowledge_graph(
node_label=label,
max_depth=max_depth,
inclusive=inclusive,
min_degree=min_degree,
max_nodes=max_nodes,
)
return router

View File

@@ -7,14 +7,9 @@ import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file
from lightrag.api.utils_api import display_splash_screen, check_env_file
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
from .config import global_args
def check_and_install_dependencies():
@@ -59,20 +54,17 @@ def main():
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # kill command
# Parse all arguments using parse_args
args = parse_args(is_uvicorn_mode=False)
# Display startup information
display_splash_screen(args)
display_splash_screen(global_args)
print("🚀 Starting LightRAG with Gunicorn")
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
print("🔍 Preloading app: Enabled")
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
print("\n\n" + "=" * 80)
print("MAIN PROCESS INITIALIZATION")
print(f"Process ID: {os.getpid()}")
print(f"Workers setting: {args.workers}")
print(f"Workers setting: {global_args.workers}")
print("=" * 80 + "\n")
# Import Gunicorn's StandaloneApplication
@@ -128,31 +120,43 @@ def main():
# Set configuration variables in gunicorn_config, prioritizing command line arguments
gunicorn_config.workers = (
args.workers if args.workers else int(os.getenv("WORKERS", 1))
global_args.workers
if global_args.workers
else int(os.getenv("WORKERS", 1))
)
# Bind configuration prioritizes command line arguments
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
host = (
global_args.host
if global_args.host != "0.0.0.0"
else os.getenv("HOST", "0.0.0.0")
)
port = (
global_args.port
if global_args.port != 9621
else int(os.getenv("PORT", 9621))
)
gunicorn_config.bind = f"{host}:{port}"
# Log level configuration prioritizes command line arguments
gunicorn_config.loglevel = (
args.log_level.lower()
if args.log_level
global_args.log_level.lower()
if global_args.log_level
else os.getenv("LOG_LEVEL", "info")
)
# Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = (
args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2))
global_args.timeout
if global_args.timeout * 2
else int(os.getenv("TIMEOUT", 150 * 2))
)
# Keepalive configuration
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
# SSL configuration prioritizes command line arguments
if args.ssl or os.getenv("SSL", "").lower() in (
if global_args.ssl or os.getenv("SSL", "").lower() in (
"true",
"1",
"yes",
@@ -160,12 +164,14 @@ def main():
"on",
):
gunicorn_config.certfile = (
args.ssl_certfile
if args.ssl_certfile
global_args.ssl_certfile
if global_args.ssl_certfile
else os.getenv("SSL_CERTFILE")
)
gunicorn_config.keyfile = (
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
global_args.ssl_keyfile
if global_args.ssl_keyfile
else os.getenv("SSL_KEYFILE")
)
# Set configuration options from the module
@@ -190,13 +196,13 @@ def main():
# Import the application
from lightrag.api.lightrag_server import get_application
return get_application(args)
return get_application(global_args)
# Create the application
app = GunicornApp("")
# Force workers to be an integer and greater than 1 for multi-process mode
workers_count = int(args.workers)
workers_count = int(global_args.workers)
if workers_count > 1:
# Set a flag to indicate we're in the main process
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"

View File

@@ -7,15 +7,13 @@ import argparse
from typing import Optional, List, Tuple
import sys
from ascii_colors import ASCIIColors
import logging
from lightrag.api import __api_version__ as api_version
from lightrag import __version__ as core_version
from fastapi import HTTPException, Security, Request, status
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
from ..prompt import PROMPTS
from .config import ollama_server_infos, global_args
def check_env_file():
@@ -36,16 +34,8 @@ def check_env_file():
return True
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
global_args = {"main_args": None}
# Get whitelist paths from environment variable, only once during initialization
default_whitelist = "/health,/api/*"
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
# Get whitelist paths from global_args, only once during initialization
whitelist_paths = global_args.whitelist_paths.split(",")
# Pre-compile path matching patterns
whitelist_patterns: List[Tuple[str, bool]] = []
@@ -63,19 +53,6 @@ for path in whitelist_paths:
auth_configured = bool(auth_handler.accounts)
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
ollama_server_infos = OllamaServerInfos()
def get_combined_auth_dependency(api_key: Optional[str] = None):
"""
Create a combined authentication dependency that implements authentication logic
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
return combined_dependency
class DefaultRAGStorageConfig:
KV_STORAGE = "JsonKVStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
GRAPH_STORAGE = "NetworkXStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
}
return default_hosts.get(
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
) # fallback to ollama if unknown
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
"""
Get value from environment variable with type conversion
Args:
env_key (str): Environment variable key
default (any): Default value if env variable is not set
value_type (type): Type to convert the value to
Returns:
any: Converted value from environment or default
"""
value = os.getenv(env_key)
if value is None:
return default
if value_type is bool:
return value.lower() in ("true", "1", "yes", "t", "on")
try:
return value_type(value)
except ValueError:
return default
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
"""
Parse command line arguments with environment variable fallback
Args:
is_uvicorn_mode: Whether running under uvicorn mode
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Server configuration
parser.add_argument(
"--host",
default=get_env_value("HOST", "0.0.0.0"),
help="Server host (default: from env or 0.0.0.0)",
)
parser.add_argument(
"--port",
type=int,
default=get_env_value("PORT", 9621, int),
help="Server port (default: from env or 9621)",
)
# Directory configuration
parser.add_argument(
"--working-dir",
default=get_env_value("WORKING_DIR", "./rag_storage"),
help="Working directory for RAG storage (default: from env or ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default=get_env_value("INPUT_DIR", "./inputs"),
help="Directory containing input documents (default: from env or ./inputs)",
)
def timeout_type(value):
if value is None:
return 150
if value is None or value == "None":
return None
return int(value)
parser.add_argument(
"--timeout",
default=get_env_value("TIMEOUT", None, timeout_type),
type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
)
# RAG configuration
parser.add_argument(
"--max-async",
type=int,
default=get_env_value("MAX_ASYNC", 4, int),
help="Maximum async operations (default: from env or 4)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=get_env_value("MAX_TOKENS", 32768, int),
help="Maximum token size (default: from env or 32768)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default=get_env_value("LOG_LEVEL", "INFO"),
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: from env or INFO)",
)
parser.add_argument(
"--verbose",
action="store_true",
default=get_env_value("VERBOSE", False, bool),
help="Enable verbose debug output(only valid for DEBUG log-level)",
)
parser.add_argument(
"--key",
type=str,
default=get_env_value("LIGHTRAG_API_KEY", None),
help="API key for authentication. This protects lightrag server against unauthorized access",
)
# Optional https parameters
parser.add_argument(
"--ssl",
action="store_true",
default=get_env_value("SSL", False, bool),
help="Enable HTTPS (default: from env or False)",
)
parser.add_argument(
"--ssl-certfile",
default=get_env_value("SSL_CERTFILE", None),
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=get_env_value("SSL_KEYFILE", None),
help="Path to SSL private key file (required if --ssl is enabled)",
)
parser.add_argument(
"--history-turns",
type=int,
default=get_env_value("HISTORY_TURNS", 3, int),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Search parameters
parser.add_argument(
"--top-k",
type=int,
default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
help="Cosine similarity threshold (default: from env or 0.4)",
)
# Ollama model name
parser.add_argument(
"--simulated-model-name",
type=str,
default=get_env_value(
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Namespace
parser.add_argument(
"--namespace-prefix",
type=str,
default=get_env_value("NAMESPACE_PREFIX", ""),
help="Prefix of the namespace",
)
parser.add_argument(
"--auto-scan-at-startup",
action="store_true",
default=False,
help="Enable automatic scanning when the program starts",
)
# Server workers configuration
parser.add_argument(
"--workers",
type=int,
default=get_env_value("WORKERS", 1, int),
help="Number of worker processes (default: from env or 1)",
)
# LLM and embedding bindings
parser.add_argument(
"--llm-binding",
type=str,
default=get_env_value("LLM_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
help="LLM binding type (default: from env or ollama)",
)
parser.add_argument(
"--embedding-binding",
type=str,
default=get_env_value("EMBEDDING_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "azure_openai"],
help="Embedding binding type (default: from env or ollama)",
)
args = parser.parse_args()
# If in uvicorn mode and workers > 1, force it to 1 and log warning
if is_uvicorn_mode and args.workers > 1:
original_workers = args.workers
args.workers = 1
# Log warning directly here
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
)
# convert relative path to absolute path
args.working_dir = os.path.abspath(args.working_dir)
args.input_dir = os.path.abspath(args.input_dir)
# Inject storage configuration from environment variables
args.kv_storage = get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
)
args.doc_status_storage = get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
)
args.graph_storage = get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
)
args.vector_storage = get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
)
# Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
# Handle openai-ollama special case
if args.llm_binding == "openai-ollama":
args.llm_binding = "openai"
args.embedding_binding = "ollama"
args.llm_binding_host = get_env_value(
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
)
args.embedding_binding_host = get_env_value(
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
)
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
# Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
# Inject chunk configuration
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
# Inject LLM cache configuration
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
)
# Inject LLM temperature configuration
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
# Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
global_args["main_args"] = args
return args
def display_splash_screen(args: argparse.Namespace) -> None:
"""
Display a colorful splash screen showing LightRAG server configuration
@@ -489,7 +173,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
# Banner
ASCIIColors.cyan(f"""
╔══════════════════════════════════════════════════════════════╗
🚀 LightRAG Server v{core_version}/{api_version}
║ 🚀 LightRAG Server v{core_version}/{api_version}
║ Fast, Lightweight RAG Server Implementation ║
╚══════════════════════════════════════════════════════════════╝
""")
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.white(" ├─ Workers: ", end="")
ASCIIColors.yellow(f"{args.workers}")
ASCIIColors.white(" ├─ CORS Origins: ", end="")
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
ASCIIColors.yellow(f"{args.cors_origins}")
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
ASCIIColors.yellow(f"{args.ssl}")
if args.ssl:
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" ├─ History Turns: ", end="")
ASCIIColors.yellow(f"{args.history_turns}")
ASCIIColors.white(" ─ API Key: ", end="")
ASCIIColors.white(" ─ API Key: ", end="")
ASCIIColors.yellow("Set" if args.key else "Not Set")
ASCIIColors.white(" └─ JWT Auth: ", end="")
ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
# Directory Configuration
ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.embedding_dim}")
# RAG Configuration
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
ASCIIColors.white(" ├─ Summary Language: ", end="")
ASCIIColors.yellow(f"{summary_language}")
ASCIIColors.yellow(f"{args.summary_language}")
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
ASCIIColors.yellow(f"{args.max_parallel_insert}")
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
protocol = "https" if args.ssl else "http"
if args.host == "0.0.0.0":
ASCIIColors.magenta("\n🌐 Server Access Information:")
ASCIIColors.white(" ├─ Local Access: ", end="")
ASCIIColors.white(" ├─ WebUI (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
ASCIIColors.white(" ├─ Remote Access: ", end="")
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
ASCIIColors.white(" ├─ API Documentation (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
ASCIIColors.white(" ─ Alternative Documentation (local): ", end="")
ASCIIColors.white(" ─ Alternative Documentation (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
ASCIIColors.white(" └─ WebUI (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
ASCIIColors.yellow("\n📝 Note:")
ASCIIColors.white(""" Since the server is running on 0.0.0.0:
ASCIIColors.magenta("\n📝 Note:")
ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
- Use 'localhost' or '127.0.0.1' for local access
- Use your machine's IP address for remote access
- To find your IP address:
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
else:
base_url = f"{protocol}://{args.host}:{args.port}"
ASCIIColors.magenta("\n🌐 Server Access Information:")
ASCIIColors.white(" ├─ Base URL: ", end="")
ASCIIColors.white(" ├─ WebUI (local): ", end="")
ASCIIColors.yellow(f"{base_url}")
ASCIIColors.white(" ├─ API Documentation: ", end="")
ASCIIColors.yellow(f"{base_url}/docs")
ASCIIColors.white(" └─ Alternative Documentation: ", end="")
ASCIIColors.yellow(f"{base_url}/redoc")
# Usage Examples
ASCIIColors.magenta("\n📚 Quick Start Guide:")
ASCIIColors.cyan("""
1. Access the Swagger UI:
Open your browser and navigate to the API documentation URL above
2. API Authentication:""")
if args.key:
ASCIIColors.cyan(""" Add the following header to your requests:
X-API-Key: <your-api-key>
""")
else:
ASCIIColors.cyan(" No authentication required\n")
ASCIIColors.cyan(""" 3. Basic Operations:
- POST /upload_document: Upload new documents to RAG
- POST /query: Query your document collection
4. Monitor the server:
- Check server logs for detailed operation information
- Use healthcheck endpoint: GET /health
""")
# Security Notice
if args.key:
ASCIIColors.yellow("\n⚠️ Security Notice:")
ASCIIColors.white(""" API Key authentication is enabled.
Make sure to include the X-API-Key header in all your requests.
""")
if args.auth_accounts:
ASCIIColors.yellow("\n⚠️ Security Notice:")
ASCIIColors.white(""" JWT authentication is enabled.
Make sure to login before making the request, and include the 'Authorization' in the header.
""")
# Ensure splash output flush to system log
sys.stdout.flush()

File diff suppressed because one or more lines are too long

lightrag/api/webui/assets/index-Cma7xY0-.js generated Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -8,8 +8,8 @@
<link rel="icon" type="image/svg+xml" href="logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title>
<script type="module" crossorigin src="/webui/assets/index-raheqJeu.js"></script>
<link rel="stylesheet" crossorigin href="/webui/assets/index-CD5HxTy1.css">
<script type="module" crossorigin src="/webui/assets/index-Cma7xY0-.js"></script>
<link rel="stylesheet" crossorigin href="/webui/assets/index-QU59h9JG.css">
</head>
<body>
<div id="root"></div>

View File

@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
async def index_done_callback(self) -> None:
"""Commit the storage operations after indexing"""
@abstractmethod
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
This abstract method defines the contract for dropping all data from a storage implementation.
Each storage type must implement this method to:
1. Clear all data from memory and/or external storage
2. Remove any associated storage files if applicable
3. Reset the storage to its initial state
4. Handle cleanup of any resources
5. Notify other processes if necessary
6. This action should persist the data to disk immediately.
Returns:
dict[str, str]: Operation status and message with the following format:
{
"status": str, # "success" or "error"
"message": str # "data dropped" on success, error details on failure
}
Implementation specific:
- On success: return {"status": "success", "message": "data dropped"}
- On failure: return {"status": "error", "message": "<error details>"}
- If not supported: return {"status": "error", "message": "unsupported"}
"""
@dataclass
class BaseVectorStorage(StorageNameSpace, ABC):
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Insert or update vectors in the storage."""
"""Insert or update vectors in the storage.
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name."""
"""Delete a single entity by its name.
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity."""
"""Delete relations for a given entity.
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
"""
pass
@abstractmethod
async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
Args:
ids: List of vector IDs to be deleted
"""
@dataclass
class BaseKVStorage(StorageNameSpace, ABC):
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
@abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Upsert data"""
"""Upsert data
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
"""
@abstractmethod
async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of document IDs to be deleted from storage
Returns:
None
"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache was dropped successfully
False: if the drop failed, or the cache mode is not supported
"""
@dataclass
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get an edge by its source and target node ids."""
"""Get node by its label identifier, return only node properties"""
@abstractmethod
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get all edges connected to a node."""
"""Get edge properties between two nodes"""
@abstractmethod
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""Delete a node from the graph."""
"""Delete a node from the graph.
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should updating the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod
async def delete_node(self, node_id: str) -> None:
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 3
self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args:
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maximum nodes to return, Defaults to 1000 (BFS if possible)
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
class DocStatus(str, Enum):
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Drop cache is not supported for Doc Status storage"""
return False
class StoragesStatus(str, Enum):
"""Storages status"""

View File

@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
"MongoKVStorage",
# "TiDBKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
# "AGEStorage",
# "MongoGraphStorage",
# "TiDBGraphStorage",
# "GremlinStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
# "TiDBVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
"implementations": [
"JsonDocStatusStorage",
"PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_docs_by_status"],
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
# "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
# "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
# "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
@@ -112,9 +90,6 @@ STORAGES = {
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorage": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
@@ -122,14 +97,14 @@ STORAGES = {
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
# "TiDBKVStorage": ".kg.tidb_impl",
# "TiDBVectorDBStorage": ".kg.tidb_impl",
# "TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
# "GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",

View File

@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
import psycopg # type: ignore
from psycopg.rows import namedtuple_row # type: ignore
from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
class AGEQueryException(Exception):
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
async def index_done_callback(self) -> None:
# AGE handles persistence automatically
pass
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all nodes and relationships in the graph.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
query = """
MATCH (n)
DETACH DELETE n
"""
await self._query(query)
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.graph_name}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -1,4 +1,5 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, final
import numpy as np
@@ -10,8 +11,8 @@ import pipmaster as pm
if not pm.is_installed("chromadb"):
pm.install("chromadb")
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from chromadb import HttpClient, PersistentClient # type: ignore
from chromadb.config import Settings # type: ignore
@final
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will delete all documents from the ChromaDB collection.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
# Get all IDs in the collection
result = self._collection.get(include=[])
if result and result["ids"] and len(result["ids"]) > 0:
# Delete all documents
self._collection.delete(ids=result["ids"])
logger.info(
f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -11,16 +11,20 @@ import pipmaster as pm
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage
if not pm.is_installed("faiss"):
pm.install("faiss")
import faiss # type: ignore
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
)
import faiss # type: ignore
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
if not pm.is_installed(FAISS_PACKAGE):
pm.install(FAISS_PACKAGE)
@final
@dataclass
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
async def delete(self, ids: list[str]):
"""
Delete vectors for the provided custom IDs.
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
to_remove = []
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
)
async def delete_entity(self, entity_name: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
await self.delete([entity_id])
async def delete_entity_relation(self, entity_name: str) -> None:
"""
Delete relations for a given entity by scanning metadata.
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
logger.debug(f"Searching relations for entity {entity_name}")
relations = []
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
results.append({**metadata, "id": metadata.get("__id__")})
return results
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will:
1. Remove the vector database storage file if it exists
2. Reinitialize the vector database client
3. Update flags to notify other processes
4. Changes are persisted to disk immediately
This method will remove all vectors from the Faiss index and delete the storage files.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
# Reset the index
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
# Remove storage files if they exist
if os.path.exists(self._faiss_index_file):
os.remove(self._faiss_index_file)
if os.path.exists(self._meta_file):
os.remove(self._meta_file)
self._id_to_meta = {}
self._load_faiss_index()
# Notify other processes
await set_all_update_flags(self.namespace)
self.storage_updated.value = False
logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
if not pm.is_installed("gremlinpython"):
pm.install("gremlinpython")
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
from gremlin_python.driver import client, serializer # type: ignore
from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
from gremlin_python.driver.protocol import GremlinServerError # type: ignore
@final
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all nodes and relationships in the graph.
This function deletes all nodes with the specified graph name property,
which automatically removes all associated edges.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
query = f"""g
.V().has('graph', {self.graph_name})
.drop()
"""
await self._query(query)
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.graph_name}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
await clear_all_update_flags(self.namespace)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
"""
if not data:
return
logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
async with self._storage_lock:
return self._data.get(id)
async def delete(self, doc_ids: list[str]):
async with self._storage_lock:
for doc_id in doc_ids:
self._data.pop(doc_id, None)
await set_all_update_flags(self.namespace)
await self.index_done_callback()
async def delete(self, doc_ids: list[str]) -> None:
"""Delete specific records from storage by their IDs
async def drop(self) -> None:
"""Drop the storage"""
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of document IDs to be deleted from storage
Returns:
None
"""
async with self._storage_lock:
self._data.clear()
await set_all_update_flags(self.namespace)
await self.index_done_callback()
any_deleted = False
for doc_id in doc_ids:
result = self._data.pop(doc_id, None)
if result is not None:
any_deleted = True
if any_deleted:
await set_all_update_flags(self.namespace)
async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources
This method will:
1. Clear all document status data from memory
2. Update flags to notify other processes
3. Trigger index_done_callback to save the empty state
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
self._data.clear()
await set_all_update_flags(self.namespace)
await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
"""
if not data:
return
logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.namespace)
async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of document IDs to be deleted from storage
Returns:
None
"""
async with self._storage_lock:
any_deleted = False
for doc_id in ids:
self._data.pop(doc_id, None)
await set_all_update_flags(self.namespace)
await self.index_done_callback()
result = self._data.pop(doc_id, None)
if result is not None:
any_deleted = True
if any_deleted:
await set_all_update_flags(self.namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of cache mode to be drop from storage
Returns:
True: if the cache drop successfully
False: if the cache drop failed
"""
if not modes:
return False
try:
await self.delete(modes)
return True
except Exception:
return False
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
This action will persist the data to disk immediately.
This method will:
1. Clear all data from memory
2. Update flags to notify other processes
3. Trigger index_done_callback to save the empty state
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
self._data.clear()
await set_all_update_flags(self.namespace)
await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
import configparser
from pymilvus import MilvusClient
from pymilvus import MilvusClient # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will delete all data from the Milvus collection.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
# Drop the collection and recreate it
if self._client.has_collection(self.namespace):
self._client.drop_collection(self.namespace)
# Recreate the collection
MilvusVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
dimension=self.embedding_func.embedding_dim,
)
logger.info(
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"):
pm.install("motor")
from motor.motor_asyncio import (
from motor.motor_asyncio import ( # type: ignore
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
from pymongo.operations import SearchIndexModel # type: ignore
from pymongo.errors import PyMongoError # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
# Mongo handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete documents with specified IDs
Args:
ids: List of document IDs to be deleted
"""
if not ids:
return
try:
result = await self._data.delete_many({"_id": {"$in": ids}})
logger.info(
f"Deleted {result.deleted_count} documents from {self.namespace}"
)
except PyMongoError as e:
logger.error(f"Error deleting documents from {self.namespace}: {e}")
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
bool: True if successful, False otherwise
"""
if not modes:
return False
try:
# Build regex pattern to match documents with the specified modes
pattern = f"^({'|'.join(modes)})_"
result = await self._data.delete_many({"_id": {"$regex": pattern}})
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
return True
except Exception as e:
logger.error(f"Error deleting cache by modes {modes}: {e}")
return False
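The regex built above encodes an assumption worth spelling out: cache IDs are stored as `<mode>_<hash>`. A worked example of the pattern construction:

```
modes = ["local", "global"]
pattern = f"^({'|'.join(modes)})_"
# pattern == "^(local|global)_"
# matches IDs like "local_abc123" or "global_9f2e...", but not "hybrid_..."
```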
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
result = await self._data.delete_many({})
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
@final
@dataclass
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
# Mongo handles persistence automatically
pass
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
result = await self._data.delete_many({})
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
@final
@dataclass
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):
logger.debug(f"Successfully deleted edges: {edges}")
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
result = await self.collection.delete_many({})
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from graph {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping graph {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
@final
@dataclass
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection and recreating vector index.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
# Delete all documents
result = await self._data.delete_many({})
deleted_count = result.deleted_count
# Recreate vector index
await self.create_vector_index_if_not_exists()
logger.info(
f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped and vector index recreated",
}
except PyMongoError as e:
logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names()

View File

@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
return self._client
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
Args:
ids: List of vector IDs to be deleted
"""
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
try:
client = await self._get_client()
storage = getattr(client, "_NanoVectorDB__storage")
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
client = await self._get_client()
return client.get(ids)
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will:
1. Remove the vector database storage file if it exists
2. Reinitialize the vector database client
3. Update flags to notify other processes
4. Changes are persisted to disk immediately
This method is intended for use in scenarios where all data needs to be removed.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
# delete _client_file_name
if os.path.exists(self._client_file_name):
os.remove(self._client_file_name)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
logger.info(
f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -1,9 +1,8 @@
import asyncio
import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, final, Optional
from typing import Any, final
import numpy as np
import configparser
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
embedding_func=embedding_func,
)
self._driver = None
self._driver_lock = asyncio.Lock()
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
USERNAME = os.environ.get(
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
),
)
DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
)
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
)
# Try to connect to the database
with GraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
) as _sync_driver:
for database in (DATABASE, None):
self._DATABASE = database
connected = False
# Try to connect to the database and create it if it doesn't exist
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try:
with _sync_driver.session(database=database) as session:
try:
session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {database} at {URI}")
connected = True
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{database} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
try:
async with self._driver.session(database=database) as session:
try:
result = await session.run("MATCH (n) RETURN n LIMIT 0")
await result.consume() # Ensure result is consumed
logger.info(f"Connected to {database} at {URI}")
connected = True
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{database} at {URI} is not available".capitalize()
)
try:
with _sync_driver.session() as session:
session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
)
try:
async with self._driver.session() as session:
result = await session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if connected:
break
if connected:
# Create index for base nodes on entity_id if it doesn't exist
try:
async with self._driver.session(database=database) as session:
# Check if index exists first
check_query = """
CALL db.indexes() YIELD name, labelsOrTypes, properties
WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
RETURN count(*) > 0 AS exists
"""
try:
check_result = await session.run(check_query)
record = await check_result.single()
await check_result.consume()
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
index_exists = record and record.get("exists", False)
async def close(self):
if not index_exists:
# Create index only if it doesn't exist
result = await session.run(
"CREATE INDEX FOR (n:base) ON (n.entity_id)"
)
await result.consume()
logger.info(
f"Created index for base nodes on entity_id in {database}"
)
except Exception:
# Fallback if db.indexes() is not supported in this Neo4j version
result = await session.run(
"CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
)
await result.consume()
except Exception as e:
logger.warning(f"Failed to create index: {str(e)}")
break
async def finalize(self):
"""Close the Neo4j driver and release all resources"""
if self._driver:
await self._driver.close()
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):
async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits"""
await self.close()
await self.finalize()
async def index_done_callback(self) -> None:
# Neo4j handles persistence automatically
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
raise
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier.
"""Get node by its label identifier, return only node properties
Args:
node_id: The node label to look up
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
)
# Return default edge properties when no edge found
return {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
# Return None when no edge found
return None
finally:
await result.consume() # Ensure result is fully consumed
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
"""
properties = node_data
entity_type = properties["entity_type"]
entity_id = properties["entity_id"]
if "entity_id" not in properties:
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
async def execute_upsert(tx: AsyncManagedTransaction):
query = (
"""
MERGE (n:base {entity_id: $properties.entity_id})
MERGE (n:base {entity_id: $entity_id})
SET n += $properties
SET n:`%s`
"""
% entity_type
)
result = await tx.run(query, properties=properties)
result = await tx.run(
query, entity_id=node_id, properties=properties
)
logger.debug(
f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
f"Upserted node with entity_id '{node_id}' and properties: {properties}"
)
await result.consume() # Ensure result is fully consumed
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
min_degree: Minimum degree of nodes to include. Defaults to 0
inclusive: Do an inclusive search if true
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maximum nodes to return by BFS, Defaults to 1000
Returns:
KnowledgeGraph: Complete connected subgraph for specified node
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
result = KnowledgeGraph()
seen_nodes = set()
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
) as session:
try:
if node_label == "*":
# First check total node count to determine if graph is truncated
count_query = "MATCH (n) RETURN count(n) as total"
count_result = None
try:
count_result = await session.run(count_query)
count_record = await count_result.single()
if count_record and count_record["total"] > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
)
finally:
if count_result:
await count_result.consume()
# Run main query to get nodes with highest degree
main_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree
WHERE degree >= $min_degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
"""
result_set = await session.run(
main_query,
{"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
)
result_set = None
try:
result_set = await session.run(
main_query,
{"max_nodes": max_nodes},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
else:
# Main query uses partial matching
main_query = """
# return await self._robust_fallback(node_label, max_depth, max_nodes)
# First try without limit to check if we need to truncate
full_query = """
MATCH (start)
WHERE
CASE
WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
ELSE start.entity_id = $entity_id
END
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
bfs: true
})
YIELD nodes, relationships
WITH start, nodes, relationships
WITH nodes, relationships, size(nodes) AS total_nodes
UNWIND nodes AS node
OPTIONAL MATCH (node)-[r]-()
WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
ORDER BY
CASE
WHEN node = start THEN 3
WHEN EXISTS((start)--(node)) THEN 2
ELSE 1
END DESC,
degree DESC
LIMIT $max_nodes
WITH collect({node: node}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
WITH collect({node: node}) AS node_info, relationships, total_nodes
RETURN node_info, relationships, total_nodes
"""
result_set = await session.run(
main_query,
{
"max_nodes": MAX_GRAPH_NODES,
"entity_id": node_label,
"inclusive": inclusive,
"max_depth": max_depth,
"min_degree": min_degree,
},
)
try:
record = await result_set.single()
if record:
# Handle nodes (compatible with multi-label cases)
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
# Try to get full result
full_result = None
try:
full_result = await session.run(
full_query,
{
"entity_id": node_label,
"max_depth": max_depth,
},
)
finally:
await result_set.consume() # Ensure result set is consumed
full_record = await full_result.single()
# If no record found, return empty KnowledgeGraph
if not full_record:
logger.debug(f"No nodes found for entity_id: {node_label}")
return result
# If record found, check node count
total_nodes = full_record["total_nodes"]
if total_nodes <= max_nodes:
# If node count is within limit, use full result directly
logger.debug(
f"Using full result with {total_nodes} nodes (no truncation needed)"
)
record = full_record
else:
# If node count exceeds limit, set truncated flag and run limited query
result.is_truncated = True
logger.info(
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
)
# Run limited query
limited_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
limit: $max_nodes,
bfs: true
})
YIELD nodes, relationships
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships
RETURN node_info, relationships
"""
result_set = None
try:
result_set = await session.run(
limited_query,
{
"entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
finally:
if full_result:
await full_result.consume()
if record:
# Handle nodes (compatible with multi-label cases)
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}")
@@ -767,110 +837,28 @@ class Neo4JStorage(BaseGraphStorage):
logger.warning(
"Neo4j: falling back to basic Cypher recursive search..."
)
if inclusive:
logger.warning(
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
)
return await self._robust_fallback(
node_label, max_depth, min_degree
return await self._robust_fallback(node_label, max_depth, max_nodes)
else:
logger.warning(
"Neo4j: APOC plugin error with wildcard query, returning empty result"
)
return result
async def _robust_fallback(
self, node_label: str, max_depth: int, min_degree: int = 0
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
"""
Fallback implementation when APOC plugin is not available or incompatible.
This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures.
only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
"""
from collections import deque
result = KnowledgeGraph()
visited_nodes = set()
visited_edges = set()
async def traverse(
node: KnowledgeGraphNode,
edge: Optional[KnowledgeGraphEdge],
current_depth: int,
):
# Check traversal limits
if current_depth > max_depth:
logger.debug(f"Reached max depth: {max_depth}")
return
if len(visited_nodes) >= MAX_GRAPH_NODES:
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
return
# Check if node already visited
if node.id in visited_nodes:
return
# Get all edges and target nodes
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=node.id)
# Get all records and release database connection
records = await results.fetch(
1000
) # Max neighbour nodes we can handled
await results.consume() # Ensure results are consumed
# Nodes not connected to start node need to check degree
if current_depth > 1 and len(records) < min_degree:
return
# Add current node to result
result.nodes.append(node)
visited_nodes.add(node.id)
# Add edge to result if it exists and not already added
if edge and edge.id not in visited_edges:
result.edges.append(edge)
visited_edges.add(edge.id)
# Prepare nodes and edges for recursive processing
nodes_to_process = []
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=list(f"{target_id}"),
properties=dict(b_node.properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{node.id}",
target=f"{target_id}",
properties=dict(rel),
)
nodes_to_process.append((target_node, target_edge))
else:
logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
# Process nodes after releasing database connection
for target_node, target_edge in nodes_to_process:
await traverse(target_node, target_edge, current_depth + 1)
visited_edge_pairs = set()  # Track processed edge pairs as sorted (source_id, target_id) tuples
# Get the starting node's data
async with self._driver.session(
@@ -889,15 +877,129 @@ class Neo4JStorage(BaseGraphStorage):
# Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode(
id=f"{node_record['n'].get('entity_id')}",
labels=list(f"{node_record['n'].get('entity_id')}"),
properties=dict(node_record["n"].properties),
labels=[node_record["n"].get("entity_id")],
properties=dict(node_record["n"]._properties),
)
finally:
await node_result.consume() # Ensure results are consumed
# Start traversal with the initial node
await traverse(start_node, None, 0)
# Initialize queue for BFS with (node, edge, depth) tuples
# edge is None for the starting node
queue = deque([(start_node, None, 0)])
# True BFS implementation using a queue
while queue and len(visited_nodes) < max_nodes:
# Dequeue the next node to process
current_node, current_edge, current_depth = queue.popleft()
# Skip if already visited or exceeds max depth
if current_node.id in visited_nodes:
continue
if current_depth > max_depth:
logger.debug(
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
)
continue
# Add current node to result
result.nodes.append(current_node)
visited_nodes.add(current_node.id)
# Add edge to result if it exists and not already added
if current_edge and current_edge.id not in visited_edges:
result.edges.append(current_edge)
visited_edges.add(current_edge.id)
# Stop if we've reached the node limit
if len(visited_nodes) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
break
# Get all edges and target nodes for the current node (even at max_depth)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=current_node.id)
# Get all records and release database connection
records = await results.fetch(1000) # Max neighbor nodes we can handle
await results.consume() # Ensure results are consumed
# Process all neighbors - capture all edges but only queue unvisited nodes
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=[target_id],
properties=dict(b_node._properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{current_node.id}",
target=f"{target_id}",
properties=dict(rel),
)
# Sort source_id and target_id so that (A,B) and (B,A) are treated as the same edge
sorted_pair = tuple(sorted([current_node.id, target_id]))
# Check whether the same edge already exists (treating edges as undirected)
if sorted_pair not in visited_edge_pairs:
# Only add the edge when the target node is already in the result or is about to be added
if target_id in visited_nodes or (
target_id not in visited_nodes
and current_depth < max_depth
):
result.edges.append(target_edge)
visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair)
# Only add unvisited nodes to the queue for further expansion
if target_id not in visited_nodes:
# Only add to queue if we're not at max depth yet
if current_depth < max_depth:
# Add node to queue with incremented depth
# Edge is already added to result, so we pass None as edge
queue.append((target_node, None, current_depth + 1))
else:
# At max depth, we've already added the edge but we don't add the node
# This prevents adding nodes beyond max_depth to the result
logger.debug(
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
)
else:
# If target node already exists in result, we don't need to add it again
logger.debug(
f"Node {target_id} already visited, edge added but node not queued"
)
else:
logger.warning(
f"Skipping edge {edge_id} due to missing entity_id on target node"
)
logger.info(
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def get_all_labels(self) -> list[str]:
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):
# Method 2: Query compatible with older versions
query = """
MATCH (n)
MATCH (n:base)
WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY label
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
This method will delete all nodes and relationships in the Neo4j database.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._driver.session(database=self._DATABASE) as session:
# Delete all nodes and relationships
query = "MATCH (n) DETACH DELETE n"
result = await session.run(query)
await result.consume() # Ensure result is fully consumed
logger.info(
f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
)
nx.write_graphml(graph, file_name)
# TODO: deprecated, remove later
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
graph = await self._get_graph()
graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
graph = await self._get_graph()
graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
graph = await self._get_graph()
if graph.has_node(node_id):
graph.remove_node(node_id)
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
# TODO: NOT USED
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
Args:
nodes: List of node IDs to be deleted
"""
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
min_degree: Minimum degree of nodes to include. Defaults to 0
inclusive: Do an inclusive search if true
node_label: Label of the starting node, "*" means all nodes
max_depth: Maximum depth of the subgraph. Defaults to 3
max_nodes: Maximum number of nodes to return by BFS. Defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
graph = await self._get_graph()
# Initialize sets for start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
result = KnowledgeGraph()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
subgraph = (
graph.copy()
) # Create a copy to avoid modifying the original graph
# Get degrees of all nodes
degrees = dict(graph.degree())
# Sort nodes by degree in descending order and take top max_nodes
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
# Check if graph is truncated
if len(sorted_nodes) > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
)
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
# Create subgraph with the highest degree nodes
subgraph = graph.subgraph(limited_nodes)
else:
# Find nodes with matching node id based on search_mode
nodes_to_explore = []
for n, attr in graph.nodes(data=True):
node_str = str(n)
if not inclusive:
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n)
# Check if node exists
if node_label not in graph:
logger.warning(f"Node {node_label} not found in the graph")
return KnowledgeGraph() # Return empty graph
if not nodes_to_explore:
logger.warning(f"No nodes found with label {node_label}")
return result
# Use BFS to get nodes
bfs_nodes = []
visited = set()
queue = [(node_label, 0)] # (node, depth) tuple
# Get subgraph using ego_graph from all matching nodes
combined_subgraph = nx.Graph()
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Breadth-first search
while queue and len(bfs_nodes) < max_nodes:
current, depth = queue.pop(0)
if current not in visited:
visited.add(current)
bfs_nodes.append(current)
# Get start nodes and direct connected nodes
if nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(
combined_subgraph.neighbors(start_node)
)
# Only explore neighbors if we haven't reached max_depth
if depth < max_depth:
# Add neighbor nodes to queue with incremented depth
neighbors = list(graph.neighbors(current))
queue.extend(
[(n, depth + 1) for n in neighbors if n not in visited]
)
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
# Check if graph is truncated - if we still have nodes in the queue
# and we've reached max_nodes, then the graph is truncated
if queue and len(bfs_nodes) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
)
subgraph = combined_subgraph
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
if node in start_nodes:
priority = 2
elif node in direct_connected_nodes:
priority = 1
else:
priority = 0
return (priority, degree)
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
:MAX_GRAPH_NODES
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
)
# Create subgraph with BFS discovered nodes
subgraph = graph.subgraph(bfs_nodes)
# Add nodes to result
seen_nodes = set()
seen_edges = set()
for node in subgraph.nodes():
if str(node) in seen_nodes:
continue
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
for edge in subgraph.edges():
source, target = edge
# Ensure unique edge_id for undirected graph
if source > target:
if str(source) > str(target):
source, target = target, source
edge_id = f"{source}-{target}"
if edge_id in seen_edges:
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
return False # Return error
return True
async def drop(self) -> dict[str, str]:
"""Drop all graph data from storage and clean up resources
This method will:
1. Remove the graph storage file if it exists
2. Reset the graph to an empty state
3. Update flags to notify other processes
4. Changes are persisted to disk immediately
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
# delete _client_file_name
if os.path.exists(self._graphml_xml_file):
os.remove(self._graphml_xml_file)
self._graph = nx.Graph()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
logger.info(
f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -8,17 +8,15 @@ import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import configparser
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
import pipmaster as pm
if not pm.is_installed("qdrant-client"):
pm.install("qdrant-client")
from qdrant_client import QdrantClient, models
from qdrant_client import QdrantClient, models # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
def compute_mdhash_id_for_qdrant(
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error searching for prefix '{prefix}': {e}")
return []
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
Args:
id: The unique identifier of the vector
Returns:
The vector data if found, or None if not found
"""
try:
# Convert to Qdrant compatible ID
qdrant_id = compute_mdhash_id_for_qdrant(id)
# Retrieve the point by ID
result = self._client.retrieve(
collection_name=self.namespace,
ids=[qdrant_id],
with_payload=True,
)
if not result:
return None
return result[0].payload
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs
Args:
ids: List of unique identifiers
Returns:
List of vector data objects that were found
"""
if not ids:
return []
try:
# Convert to Qdrant compatible IDs
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
# Retrieve the points by IDs
results = self._client.retrieve(
collection_name=self.namespace,
ids=qdrant_ids,
with_payload=True,
)
return [point.payload for point in results]
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will delete all data from the Qdrant collection.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
# Delete the collection and recreate it
if self._client.collection_exists(self.namespace):
self._client.delete_collection(self.namespace)
# Recreate the collection
QdrantVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim,
distance=models.Distance.COSINE,
),
)
logger.info(
f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -12,6 +12,7 @@ if not pm.is_installed("redis"):
from redis.asyncio import Redis, ConnectionPool
from redis.exceptions import RedisError, ConnectionError
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseKVStorage
import json
@@ -121,7 +122,11 @@ class RedisKVStorage(BaseKVStorage):
except (TypeError, ValueError) as e:  # json.dumps raises TypeError/ValueError; json.JSONEncodeError does not exist
logger.error(f"JSON encode error during upsert: {e}")
raise
async def index_done_callback(self) -> None:
# Redis handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete entries with specified IDs"""
if not ids:
@@ -138,71 +143,52 @@ class RedisKVStorage(BaseKVStorage):
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
)
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
Importance notes for Redis storage:
1. This will immediately delete the specified cache modes from Redis
Args:
modes (list[str]): List of cache mode to be drop from storage
Returns:
True: if the cache drop successfully
False: if the cache drop failed
"""
if not modes:
return False
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
await self.delete(modes)
return True
except Exception:
return False
async with self._get_redis_connection() as redis:
result = await redis.delete(f"{self.namespace}:{entity_id}")
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all keys under the current namespace.
if result:
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
async with self._get_redis_connection() as redis:
try:
keys = await redis.keys(f"{self.namespace}:*")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity"""
try:
async with self._get_redis_connection() as redis:
cursor = 0
relation_keys = []
pattern = f"{self.namespace}:*"
while True:
cursor, keys = await redis.scan(cursor, match=pattern)
# Process keys in batches
if keys:
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
for key, value in zip(keys, values):
if value:
try:
data = json.loads(value)
if (
data.get("src_id") == entity_name
or data.get("tgt_id") == entity_name
):
relation_keys.append(key)
except json.JSONDecodeError:
logger.warning(f"Invalid JSON in key {key}")
continue
pipe.delete(key)
results = await pipe.execute()
deleted_count = sum(results)
if cursor == 0:
break
# Delete relations in batches
if relation_keys:
# Delete in chunks to avoid too many arguments
chunk_size = 1000
for i in range(0, len(relation_keys), chunk_size):
chunk = relation_keys[i:i + chunk_size]
deleted = await redis.delete(*chunk)
logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
return {"status": "success", "message": f"{deleted_count} keys dropped"}
else:
logger.debug(f"No relations found for entity {entity_name}")
logger.info(f"No keys found to drop in {self.namespace}")
return {"status": "success", "message": "no keys to drop"}
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
async def index_done_callback(self) -> None:
# Redis handles persistence automatically
pass
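The drop implementation above boils down to a namespace-scoped key sweep; a standalone sketch of that pattern with redis.asyncio follows (connection handling is assumed, not shown):
```
# Hedged sketch: namespace-scoped drop, mirroring the pipeline-based sweep above.
from redis.asyncio import Redis

async def drop_namespace(redis: Redis, namespace: str) -> int:
    keys = await redis.keys(f"{namespace}:*")  # KEYS is O(N); acceptable for admin-style drops
    return await redis.delete(*keys) if keys else 0
```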

View File

@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from sqlalchemy import create_engine, text
from sqlalchemy import create_engine, text # type: ignore
class TiDB:
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
# Ti handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete records with specified IDs from the storage.
Args:
ids: List of record IDs to be deleted
"""
if not ids:
return
try:
table_name = namespace_to_table_name(self.namespace)
id_field = namespace_to_id(self.namespace)
if not table_name or not id_field:
logger.error(f"Unknown namespace for deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.info(
f"Successfully deleted {len(ids)} records from {self.namespace}"
)
except Exception as e:
logger.error(f"Error deleting records from {self.namespace}: {e}")
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
bool: True if successful, False otherwise
"""
if not modes:
return False
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return False
if table_name != "LIGHTRAG_LLM_CACHE":
return False
# Build a MySQL-style IN clause
modes_list = ", ".join([f"'{mode}'" for mode in modes])
sql = f"""
DELETE FROM {table_name}
WHERE workspace = :workspace
AND mode IN ({modes_list})
"""
logger.info(f"Deleting cache by modes: {modes}")
await self.db.execute(sql, {"workspace": self.db.workspace})
return True
except Exception as e:
logger.error(f"Error deleting cache by modes {modes}: {e}")
return False
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@final
@dataclass
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs from the storage.
Args:
ids: List of vector IDs to be deleted
"""
if not ids:
return
table_name = namespace_to_table_name(self.namespace)
id_field = namespace_to_id(self.namespace)
if not table_name or not id_field:
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None:
raise NotImplementedError
"""Delete an entity by its name from the vector storage.
Args:
entity_name: The name of the entity to delete
"""
try:
# Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace AND name = :entity_name"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
raise NotImplementedError
"""Delete all relations associated with an entity.
Args:
entity_name: The name of the entity whose relations should be deleted
"""
try:
# Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
async def index_done_callback(self) -> None:
# Ti handles persistence automatically
pass
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
# Ti handles persistence automatically
pass
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
drop_sql = """
DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
"""
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
async def delete_node(self, node_id: str) -> None:
"""Delete a node and all its related edges
@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
""",
# Drop tables
"drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}
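The shared `drop_specifiy_table_workspace` template keeps the workspace binding parameterized and only interpolates the table identifier; a sketch of how a storage class resolves it (values are illustrative):
```
# Hedged sketch: resolving the shared drop template for one storage class.
template = "DELETE FROM {table_name} WHERE workspace = :workspace"
sql = template.format(table_name="LIGHTRAG_LLM_CACHE")  # table from namespace mapping
params = {"workspace": "default"}  # hypothetical workspace value
# await db.execute(sql, params)    # db: the TiDB helper used throughout this file
```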

View File

@@ -13,7 +13,6 @@ import pandas as pd
from lightrag.kg import (
STORAGE_ENV_REQUIREMENTS,
STORAGES,
verify_storage_implementation,
)
@@ -230,6 +229,7 @@ class LightRAG:
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional parameters for vector database storage."""
# TODO: deprecated, remove in the future, use WORKSPACE instead
namespace_prefix: str = field(default="")
"""Prefix for namespacing stored data across different environments."""
@@ -510,36 +510,22 @@ class LightRAG:
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
max_nodes: int = 1000,
) -> KnowledgeGraph:
"""Get knowledge graph for a given label
Args:
node_label (str): Label to get knowledge graph for
max_depth (int): Maximum depth of graph
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
Returns:
KnowledgeGraph: Knowledge graph containing nodes and edges
"""
# get params supported by get_knowledge_graph of specified storage
import inspect
storage_params = inspect.signature(
self.chunk_entity_relation_graph.get_knowledge_graph
).parameters
kwargs = {"node_label": node_label, "max_depth": max_depth}
if "min_degree" in storage_params and min_degree > 0:
kwargs["min_degree"] = min_degree
if "inclusive" in storage_params:
kwargs["inclusive"] = inclusive
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label, max_depth, max_nodes
)
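With the inspect-based parameter probing removed, the call path is a plain positional forward; a hypothetical caller looks like this (`rag` is an initialized LightRAG instance):
```
# Hedged sketch: the simplified knowledge-graph query.
async def demo(rag):
    kg = await rag.get_knowledge_graph("Alice", max_depth=2, max_nodes=500)
    return len(kg.nodes), kg.is_truncated
```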
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
@@ -1449,6 +1435,7 @@ class LightRAG:
loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_entity(entity_name))
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def adelete_by_entity(self, entity_name: str) -> None:
try:
await self.entities_vdb.delete_entity(entity_name)
@@ -1486,6 +1473,7 @@ class LightRAG:
self.adelete_by_relation(source_entity, target_entity)
)
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
"""Asynchronously delete a relation between two entities.
@@ -1494,6 +1482,7 @@ class LightRAG:
target_entity: Name of the target entity
"""
try:
# TODO: check if has_edge function works on reverse relation
# Check if the relation exists
edge_exists = await self.chunk_entity_relation_graph.has_edge(
source_entity, target_entity
@@ -1554,6 +1543,7 @@ class LightRAG:
"""
return await self.doc_status.get_docs_by_status(status)
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data
@@ -1586,6 +1576,8 @@ class LightRAG:
chunk_ids = set(related_chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
# TODO: self.entities_vdb.client_storage only works for local storage, need to fix this
# 3. Before deleting, check the related entities and relationships for these chunks
for chunk_id in chunk_ids:
# Check entities
@@ -1857,24 +1849,6 @@ class LightRAG:
return result
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
async def aclear_cache(self, modes: list[str] | None = None) -> None:
"""Clear cache data from the LLM response cache storage.
@@ -1906,12 +1880,18 @@ class LightRAG:
try:
# Reset the cache storage for specified mode
if modes:
await self.llm_response_cache.delete(modes)
logger.info(f"Cleared cache for modes: {modes}")
success = await self.llm_response_cache.drop_cache_by_modes(modes)
if success:
logger.info(f"Cleared cache for modes: {modes}")
else:
logger.warning(f"Failed to clear cache for modes: {modes}")
else:
# Clear all modes
await self.llm_response_cache.delete(valid_modes)
logger.info("Cleared all cache")
success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
if success:
logger.info("Cleared all cache")
else:
logger.warning("Failed to clear all cache")
await self.llm_response_cache.index_done_callback()
@@ -1922,6 +1902,7 @@ class LightRAG:
"""Synchronous version of aclear_cache."""
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
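Cache clearing now routes through `drop_cache_by_modes` and reports success per call; a hypothetical usage (`rag` assumed initialized):
```
# Hedged sketch: selective vs. full cache clearing.
async def clear_caches(rag):
    await rag.aclear_cache(modes=["local", "global"])  # subset of valid modes
    await rag.aclear_cache()                           # clears all valid modes
```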
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def aedit_entity(
self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
) -> dict[str, Any]:
@@ -2134,6 +2115,7 @@ class LightRAG:
]
)
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def aedit_relation(
self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
) -> dict[str, Any]:
@@ -2448,6 +2430,7 @@ class LightRAG:
self.acreate_relation(source_entity, target_entity, relation_data)
)
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def amerge_entities(
self,
source_entities: list[str],

View File

@@ -44,6 +44,47 @@ class InvalidResponseError(Exception):
pass
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
client_configs: dict[str, Any] = None,
) -> AsyncOpenAI:
"""Create an AsyncOpenAI client with the given configuration.
Args:
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
Returns:
An AsyncOpenAI client instance.
"""
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
return AsyncOpenAI(**merged_configs)
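A short sketch of the precedence rule documented above: options in `client_configs` are applied first, then overridden by the explicit parameters (the timeout and retry values are illustrative, not repository defaults):
```
# Hedged sketch: client_configs are merged under explicit parameters.
# api_key falls back to the OPENAI_API_KEY environment variable.
client = create_openai_async_client(
    base_url="https://api.openai.com/v1",
    client_configs={"timeout": 30.0, "max_retries": 5},  # passed through to AsyncOpenAI
)
```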
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -61,29 +102,52 @@ async def openai_complete_if_cache(
token_tracker: Any | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support.
Args:
model: The OpenAI model to use.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
Returns:
The completed text or an async iterator of text chunks if streaming.
Raises:
InvalidResponseError: If the response from OpenAI is invalid or empty.
APIConnectionError: If there is a connection error with the OpenAI API.
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
if history_messages is None:
history_messages = []
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
# Set openai logger level to INFO when VERBOSE_DEBUG is off
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
logging.getLogger("openai").setLevel(logging.INFO)
openai_async_client = (
AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None
else AsyncOpenAI(
base_url=base_url, default_headers=default_headers, api_key=api_key
)
# Extract client configuration options
client_configs = kwargs.pop("openai_client_configs", {})
# Create the OpenAI client
openai_async_client = create_openai_async_client(
api_key=api_key, base_url=base_url, client_configs=client_configs
)
# Remove special kwargs that shouldn't be passed to OpenAI
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
# Prepare messages
messages: list[dict[str, Any]] = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -272,21 +336,32 @@ async def openai_embed(
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
client_configs: dict[str, Any] = None,
) -> np.ndarray:
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
"""Generate embeddings for a list of texts using OpenAI's API.
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None
else AsyncOpenAI(
base_url=base_url, default_headers=default_headers, api_key=api_key
)
Args:
texts: List of texts to embed.
model: The OpenAI embedding model to use.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
Returns:
A numpy array of embeddings, one per input text.
Raises:
APIConnectionError: If there is a connection error with the OpenAI API.
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
# Create the OpenAI client
openai_async_client = create_openai_async_client(
api_key=api_key, base_url=base_url, client_configs=client_configs
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
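The same hook is exposed on the embedding path; a hypothetical call (model and options illustrative):
```
# Hedged sketch: embedding with extra client options.
async def embed_demo():
    vectors = await openai_embed(
        ["hello", "world"],
        model="text-embedding-3-small",
        client_configs={"timeout": 30.0},
    )
    return vectors.shape  # numpy array: one row per input text
```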

View File

@@ -26,7 +26,6 @@ from .utils import (
CacheData,
statistic_data,
get_conversation_turns,
verbose_debug,
)
from .base import (
BaseGraphStorage,
@@ -442,6 +441,13 @@ async def extract_entities(
processed_chunks = 0
total_chunks = len(ordered_chunks)
total_entities_count = 0
total_relations_count = 0
# Get lock manager from shared storage
from .kg.shared_storage import get_graph_db_lock
graph_db_lock = get_graph_db_lock(enable_logging=False)
async def _user_llm_func_with_cache(
input_text: str, history_messages: list[dict[str, str]] = None
@@ -540,7 +546,7 @@ async def extract_entities(
chunk_key_dp (tuple[str, TextChunkSchema]):
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
"""
nonlocal processed_chunks
nonlocal processed_chunks, total_entities_count, total_relations_count
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
@@ -598,102 +604,74 @@ async def extract_entities(
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
return dict(maybe_nodes), dict(maybe_edges)
tasks = [_process_single_content(c) for c in ordered_chunks]
results = await asyncio.gather(*tasks)
# Use graph database lock to ensure atomic merges and updates
chunk_entities_data = []
chunk_relationships_data = []
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for m_nodes, m_edges in results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
from .kg.shared_storage import get_graph_db_lock
graph_db_lock = get_graph_db_lock(enable_logging=False)
# Ensure that nodes and edges are merged and upserted atomically
async with graph_db_lock:
all_entities_data = await asyncio.gather(
*[
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
all_relationships_data = await asyncio.gather(
*[
_merge_edges_then_upsert(
k[0], k[1], v, knowledge_graph_inst, global_config
async with graph_db_lock:
# Process and update entities
for entity_name, entities in maybe_nodes.items():
entity_data = await _merge_nodes_then_upsert(
entity_name, entities, knowledge_graph_inst, global_config
)
for k, v in maybe_edges.items()
]
)
chunk_entities_data.append(entity_data)
if not (all_entities_data or all_relationships_data):
log_message = "Didn't extract any entities and relationships."
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
return
# Process and update relationships
for edge_key, edges in maybe_edges.items():
# Ensure edge direction consistency
sorted_edge_key = tuple(sorted(edge_key))
edge_data = await _merge_edges_then_upsert(
sorted_edge_key[0],
sorted_edge_key[1],
edges,
knowledge_graph_inst,
global_config,
)
chunk_relationships_data.append(edge_data)
if not all_entities_data:
log_message = "Didn't extract any entities"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if not all_relationships_data:
log_message = "Didn't extract any relationships"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Update vector database (within the same lock to ensure atomicity)
if entity_vdb is not None and chunk_entities_data:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in chunk_entities_data
}
await entity_vdb.upsert(data_for_vdb)
log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
if relationships_vdb is not None and chunk_relationships_data:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in chunk_relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
# Update counters
total_entities_count += len(chunk_entities_data)
total_relations_count += len(chunk_relationships_data)
# Handle all chunks in parallel
tasks = [_process_single_content(c) for c in ordered_chunks]
await asyncio.gather(*tasks)
log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
verbose_debug(
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
)
verbose_debug(f"New relationships:{all_relationships_data}")
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in all_entities_data
}
await entity_vdb.upsert(data_for_vdb)
if relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in all_relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
async def kg_query(
@@ -720,8 +698,7 @@ async def kg_query(
if cached_response is not None:
return cached_response
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
)
@@ -817,6 +794,38 @@ async def kg_query(
return response
async def get_keywords_from_query(
query: str,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
"""
Retrieves high-level and low-level keywords for RAG operations.
This function checks if keywords are already provided in query parameters,
and if not, extracts them from the query text using LLM.
Args:
query: The user's query text
query_param: Query parameters that may contain pre-defined keywords
global_config: Global configuration dictionary
hashing_kv: Optional key-value storage for caching results
Returns:
A tuple containing (high_level_keywords, low_level_keywords)
"""
# Check if pre-defined keywords are already provided
if query_param.hl_keywords or query_param.ll_keywords:
return query_param.hl_keywords, query_param.ll_keywords
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
)
return hl_keywords, ll_keywords
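A sketch of the short-circuit behavior documented above: pre-set keywords on QueryParam skip the LLM extraction entirely (values are illustrative):
```
# Hedged sketch: pre-defined keywords bypass extract_keywords_only.
param = QueryParam(mode="hybrid")
param.hl_keywords = ["graph storage"]
param.ll_keywords = ["Neo4j", "BFS"]
# hl, ll = await get_keywords_from_query(query, param, global_config, hashing_kv)
# -> returns the pre-set lists without an LLM call
```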
async def extract_keywords_only(
text: str,
param: QueryParam,
@@ -957,8 +966,7 @@ async def mix_kg_vector_query(
# 2. Execute knowledge graph and vector searches in parallel
async def get_kg_context():
try:
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
)
@@ -1339,7 +1347,9 @@ async def _get_node_data(
text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"], t["file_path"]])
text_units_section_list.append(
[i, t["content"], t.get("file_path", "unknown_source")]
)
text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context, relations_context, text_units_context
@@ -2043,16 +2053,13 @@ async def query_with_keywords(
Query response or async iterator
"""
# Extract keywords
hl_keywords, ll_keywords = await extract_keywords_only(
text=query,
param=param,
hl_keywords, ll_keywords = await get_keywords_from_query(
query=query,
query_param=param,
global_config=global_config,
hashing_kv=hashing_kv,
)
param.hl_keywords = hl_keywords
param.ll_keywords = ll_keywords
# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)

View File

@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
class KnowledgeGraph(BaseModel):
nodes: list[KnowledgeGraphNode] = []
edges: list[KnowledgeGraphEdge] = []
is_truncated: bool = False
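Downstream consumers can now detect truncation on the model itself; a minimal sketch:
```
# Hedged sketch: checking the new flag.
kg = KnowledgeGraph(nodes=[], edges=[], is_truncated=True)
if kg.is_truncated:
    print("raise max_nodes or narrow node_label to retrieve the full graph")
```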