Merge branch 'main' into main
@@ -291,11 +291,9 @@ LightRAG uses 4 types of storage for different purposes:

```
JsonKVStorage    JsonFile (default)
MongoKVStorage   MongoDB
RedisKVStorage   Redis
TiDBKVStorage    TiDB
PGKVStorage      Postgres
OracleKVStorage  Oracle
```

* Supported GRAPH_STORAGE implementation names
@@ -303,25 +301,19 @@ OracleKVStorage Oracle

```
NetworkXStorage    NetworkX (default)
Neo4JStorage       Neo4J
MongoGraphStorage  MongoDB
TiDBGraphStorage   TiDB
AGEStorage         AGE
GremlinStorage     Gremlin
PGGraphStorage     Postgres
OracleGraphStorage Oracle
```

* Supported VECTOR_STORAGE implementation names

```
NanoVectorDBStorage    NanoVector (default)
PGVectorStorage        Postgres
MilvusVectorDBStorge   Milvus
ChromaVectorDBStorage  Chroma
TiDBVectorDBStorage    TiDB
FaissVectorDBStorage   Faiss
QdrantVectorDBStorage  Qdrant
OracleVectorDBStorage  Oracle
MongoVectorDBStorage   MongoDB
```
@@ -302,11 +302,9 @@ Each storage type has several implementations:

```
JsonKVStorage    JsonFile (default)
MongoKVStorage   MongoDB
RedisKVStorage   Redis
TiDBKVStorage    TiDB
PGKVStorage      Postgres
OracleKVStorage  Oracle
```

* Supported GRAPH_STORAGE implementation names

@@ -314,25 +312,19 @@ OracleKVStorage Oracle

```
NetworkXStorage    NetworkX (default)
Neo4JStorage       Neo4J
MongoGraphStorage  MongoDB
TiDBGraphStorage   TiDB
AGEStorage         AGE
GremlinStorage     Gremlin
PGGraphStorage     Postgres
OracleGraphStorage Oracle
```

* Supported VECTOR_STORAGE implementation names

```
NanoVectorDBStorage    NanoVector (default)
MilvusVectorDBStorage  Milvus
ChromaVectorDBStorage  Chroma
TiDBVectorDBStorage    TiDB
PGVectorStorage        Postgres
FaissVectorDBStorage   Faiss
QdrantVectorDBStorage  Qdrant
OracleVectorDBStorage  Oracle
MongoVectorDBStorage   MongoDB
```
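These implementation names are not hard-coded; they are selected through environment variables that the new `lightrag/api/config.py` below injects (`LIGHTRAG_KV_STORAGE`, `LIGHTRAG_GRAPH_STORAGE`, `LIGHTRAG_VECTOR_STORAGE`, `LIGHTRAG_DOC_STATUS_STORAGE`). A hedged sketch of switching backends; the particular combination shown is only an example, not a recommended setup:

```python
# Sketch: pick storage backends by env var before the server parses its config.
# The variable names come from config.py in this commit; the values are
# examples taken from the lists above.
import os

os.environ["LIGHTRAG_KV_STORAGE"] = "RedisKVStorage"            # default: JsonKVStorage
os.environ["LIGHTRAG_GRAPH_STORAGE"] = "Neo4JStorage"           # default: NetworkXStorage
os.environ["LIGHTRAG_VECTOR_STORAGE"] = "FaissVectorDBStorage"  # default: NanoVectorDBStorage

# config.parse_args() then copies these into args.kv_storage,
# args.graph_storage, and args.vector_storage (see below).
```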
lightrag/api/__init__.py
@@ -1 +1 @@
__api_version__ = "1.2.8"
__api_version__ = "0136"
lightrag/api/auth.py
@@ -1,9 +1,11 @@
import os
from datetime import datetime, timedelta

import jwt
from dotenv import load_dotenv
from fastapi import HTTPException, status
from pydantic import BaseModel
from dotenv import load_dotenv

from .config import global_args

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):

class AuthHandler:
    def __init__(self):
        self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
        self.algorithm = "HS256"
        self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
        self.guest_expire_hours = int(os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2))
        self.secret = global_args.token_secret
        self.algorithm = global_args.jwt_algorithm
        self.expire_hours = global_args.token_expire_hours
        self.guest_expire_hours = global_args.guest_token_expire_hours
        self.accounts = {}
        auth_accounts = os.getenv("AUTH_ACCOUNTS")
        auth_accounts = global_args.auth_accounts
        if auth_accounts:
            for account in auth_accounts.split(","):
                username, password = account.split(":", 1)
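For reference, the `AUTH_ACCOUNTS` value consumed above is a comma-separated list of `username:password` pairs. A minimal sketch of the expected format and the resulting accounts dict; only the `split(",")` / `split(":", 1)` logic comes from the hunk above, the credentials are placeholders:

```python
# How AuthHandler parses AUTH_ACCOUNTS, per the loop in the diff above.
auth_accounts = "admin:secret123,alice:p@ss:word"  # example env value, not from the diff

accounts = {}
for account in auth_accounts.split(","):
    # split on the first ":" only, so passwords may themselves contain ":"
    username, password = account.split(":", 1)
    accounts[username] = password

print(accounts)  # {'admin': 'secret123', 'alice': 'p@ss:word'}
```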
lightrag/api/config.py (new file, +335 lines)
@@ -0,0 +1,335 @@
"""
Configs for the LightRAG API.
"""

import os
import argparse
import logging
from dotenv import load_dotenv

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)


class OllamaServerInfos:
    # Constants for emulated Ollama model information
    LIGHTRAG_NAME = "lightrag"
    LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
    LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
    LIGHTRAG_SIZE = 7365960935  # it's a dummy value
    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
    LIGHTRAG_DIGEST = "sha256:lightrag"


ollama_server_infos = OllamaServerInfos()


class DefaultRAGStorageConfig:
    KV_STORAGE = "JsonKVStorage"
    VECTOR_STORAGE = "NanoVectorDBStorage"
    GRAPH_STORAGE = "NetworkXStorage"
    DOC_STATUS_STORAGE = "JsonDocStatusStorage"


def get_default_host(binding_type: str) -> str:
    default_hosts = {
        "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
        "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
        "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
        "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
    }
    return default_hosts.get(
        binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
    )  # fallback to ollama if unknown


def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
    """
    Get value from environment variable with type conversion

    Args:
        env_key (str): Environment variable key
        default (any): Default value if env variable is not set
        value_type (type): Type to convert the value to

    Returns:
        any: Converted value from environment or default
    """
    value = os.getenv(env_key)
    if value is None:
        return default

    if value_type is bool:
        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except ValueError:
        return default
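A quick illustration of `get_env_value`'s conversion rules, following directly from the function above (the env values below are hypothetical): unset variables fall back to the default, booleans accept several truthy spellings, and failed conversions also fall back.

```python
import os

# hypothetical environment, set only for demonstration
os.environ["VERBOSE"] = "yes"
os.environ["TOP_K"] = "not-a-number"

print(get_env_value("VERBOSE", False, bool))     # True  ("yes" is a truthy spelling)
print(get_env_value("TOP_K", 60, int))           # 60    (ValueError -> default)
print(get_env_value("MISSING_KEY", "fallback"))  # "fallback" (unset -> default)
```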
def parse_args() -> argparse.Namespace:
    """
    Parse command line arguments with environment variable fallback

    Returns:
        argparse.Namespace: Parsed arguments
    """

    parser = argparse.ArgumentParser(
        description="LightRAG FastAPI Server with separate working and input directories"
    )

    # Server configuration
    parser.add_argument(
        "--host",
        default=get_env_value("HOST", "0.0.0.0"),
        help="Server host (default: from env or 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=get_env_value("PORT", 9621, int),
        help="Server port (default: from env or 9621)",
    )

    # Directory configuration
    parser.add_argument(
        "--working-dir",
        default=get_env_value("WORKING_DIR", "./rag_storage"),
        help="Working directory for RAG storage (default: from env or ./rag_storage)",
    )
    parser.add_argument(
        "--input-dir",
        default=get_env_value("INPUT_DIR", "./inputs"),
        help="Directory containing input documents (default: from env or ./inputs)",
    )

    def timeout_type(value):
        if value is None:
            return 150
        if value is None or value == "None":
            return None
        return int(value)

    parser.add_argument(
        "--timeout",
        default=get_env_value("TIMEOUT", None, timeout_type),
        type=timeout_type,
        help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
    )

    # RAG configuration
    parser.add_argument(
        "--max-async",
        type=int,
        default=get_env_value("MAX_ASYNC", 4, int),
        help="Maximum async operations (default: from env or 4)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=get_env_value("MAX_TOKENS", 32768, int),
        help="Maximum token size (default: from env or 32768)",
    )

    # Logging configuration
    parser.add_argument(
        "--log-level",
        default=get_env_value("LOG_LEVEL", "INFO"),
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level (default: from env or INFO)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=get_env_value("VERBOSE", False, bool),
        help="Enable verbose debug output (only valid for DEBUG log-level)",
    )

    parser.add_argument(
        "--key",
        type=str,
        default=get_env_value("LIGHTRAG_API_KEY", None),
        help="API key for authentication. This protects lightrag server against unauthorized access",
    )

    # Optional https parameters
    parser.add_argument(
        "--ssl",
        action="store_true",
        default=get_env_value("SSL", False, bool),
        help="Enable HTTPS (default: from env or False)",
    )
    parser.add_argument(
        "--ssl-certfile",
        default=get_env_value("SSL_CERTFILE", None),
        help="Path to SSL certificate file (required if --ssl is enabled)",
    )
    parser.add_argument(
        "--ssl-keyfile",
        default=get_env_value("SSL_KEYFILE", None),
        help="Path to SSL private key file (required if --ssl is enabled)",
    )

    parser.add_argument(
        "--history-turns",
        type=int,
        default=get_env_value("HISTORY_TURNS", 3, int),
        help="Number of conversation history turns to include (default: from env or 3)",
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=get_env_value("TOP_K", 60, int),
        help="Number of most similar results to return (default: from env or 60)",
    )
    parser.add_argument(
        "--cosine-threshold",
        type=float,
        default=get_env_value("COSINE_THRESHOLD", 0.2, float),
        help="Cosine similarity threshold (default: from env or 0.2)",
    )

    # Ollama model name
    parser.add_argument(
        "--simulated-model-name",
        type=str,
        default=get_env_value(
            "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
        ),
        help="Name of the simulated Ollama model (default: from env or lightrag:latest)",
    )

    # Namespace
    parser.add_argument(
        "--namespace-prefix",
        type=str,
        default=get_env_value("NAMESPACE_PREFIX", ""),
        help="Prefix of the namespace",
    )

    parser.add_argument(
        "--auto-scan-at-startup",
        action="store_true",
        default=False,
        help="Enable automatic scanning when the program starts",
    )

    # Server workers configuration
    parser.add_argument(
        "--workers",
        type=int,
        default=get_env_value("WORKERS", 1, int),
        help="Number of worker processes (default: from env or 1)",
    )

    # LLM and embedding bindings
    parser.add_argument(
        "--llm-binding",
        type=str,
        default=get_env_value("LLM_BINDING", "ollama"),
        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
        help="LLM binding type (default: from env or ollama)",
    )
    parser.add_argument(
        "--embedding-binding",
        type=str,
        default=get_env_value("EMBEDDING_BINDING", "ollama"),
        choices=["lollms", "ollama", "openai", "azure_openai"],
        help="Embedding binding type (default: from env or ollama)",
    )

    args = parser.parse_args()

    # convert relative path to absolute path
    args.working_dir = os.path.abspath(args.working_dir)
    args.input_dir = os.path.abspath(args.input_dir)

    # Inject storage configuration from environment variables
    args.kv_storage = get_env_value(
        "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
    )
    args.doc_status_storage = get_env_value(
        "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
    )
    args.graph_storage = get_env_value(
        "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
    )
    args.vector_storage = get_env_value(
        "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
    )

    # Get MAX_PARALLEL_INSERT from environment
    args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)

    # Handle openai-ollama special case
    if args.llm_binding == "openai-ollama":
        args.llm_binding = "openai"
        args.embedding_binding = "ollama"

    args.llm_binding_host = get_env_value(
        "LLM_BINDING_HOST", get_default_host(args.llm_binding)
    )
    args.embedding_binding_host = get_env_value(
        "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
    )
    args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
    args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")

    # Inject model configuration
    args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
    args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
    args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
    args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)

    # Inject chunk configuration
    args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
    args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)

    # Inject LLM cache configuration
    args.enable_llm_cache_for_extract = get_env_value(
        "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
    )

    # Inject LLM temperature configuration
    args.temperature = get_env_value("TEMPERATURE", 0.5, float)

    # Select Document loading tool (DOCLING, DEFAULT)
    args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")

    # Add environment variables that were previously read directly
    args.cors_origins = get_env_value("CORS_ORIGINS", "*")
    args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
    args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")

    # For JWT Auth
    args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
    args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
    args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
    args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
    args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")

    ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name

    return args


def update_uvicorn_mode_config():
    # If in uvicorn mode and workers > 1, force it to 1 and log warning
    if global_args.workers > 1:
        original_workers = global_args.workers
        global_args.workers = 1
        # Log warning directly here
        logging.warning(
            f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
        )


global_args = parse_args()
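The last line above makes `global_args` a module-level singleton: `parse_args()` runs once at import time, and every other module reuses the parsed result instead of re-reading the environment. A minimal sketch of the consuming side (`describe_server` is a hypothetical helper, not part of the diff):

```python
# Sketch: consumers import the already-parsed config instead of calling
# parse_args() themselves; auth.py and the routers below do exactly this.
from lightrag.api.config import global_args

def describe_server() -> str:  # hypothetical helper for illustration
    return (
        f"host={global_args.host} port={global_args.port} "
        f"kv={global_args.kv_storage} graph={global_args.graph_storage}"
    )
```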
lightrag/api/lightrag_server.py
@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
    get_combined_auth_dependency,
    parse_args,
    get_default_host,
    display_splash_screen,
    check_env_file,
)
from .config import (
    global_args,
    update_uvicorn_mode_config,
    get_default_host,
)
import sys
from lightrag import LightRAG, __version__ as core_version
from lightrag.api import __api_version__

@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)

webui_title = os.getenv("WEBUI_TITLE")
webui_description = os.getenv("WEBUI_DESCRIPTION")

# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")

@@ -164,10 +171,10 @@ def create_app(args):
    app = FastAPI(**app_kwargs)

    def get_cors_origins():
        """Get allowed origins from environment variable
        """Get allowed origins from global_args
        Returns a list of allowed origins, defaults to ["*"] if not set
        """
        origins_str = os.getenv("CORS_ORIGINS", "*")
        origins_str = global_args.cors_origins
        if origins_str == "*":
            return ["*"]
        return [origin.strip() for origin in origins_str.split(",")]

@@ -315,9 +322,10 @@ def create_app(args):
                "similarity_threshold": 0.95,
                "use_llm_check": False,
            },
            namespace_prefix=args.namespace_prefix,
            # namespace_prefix=args.namespace_prefix,
            auto_manage_storages_states=False,
            max_parallel_insert=args.max_parallel_insert,
            addon_params={"language": args.summary_language},
        )
    else:  # azure_openai
        rag = LightRAG(
@@ -345,9 +353,10 @@ def create_app(args):
                "similarity_threshold": 0.95,
                "use_llm_check": False,
            },
            namespace_prefix=args.namespace_prefix,
            # namespace_prefix=args.namespace_prefix,
            auto_manage_storages_states=False,
            max_parallel_insert=args.max_parallel_insert,
            addon_params={"language": args.summary_language},
        )

    # Add routes
@@ -381,6 +390,8 @@ def create_app(args):
                "message": "Authentication is disabled. Using guest access.",
                "core_version": core_version,
                "api_version": __api_version__,
                "webui_title": webui_title,
                "webui_description": webui_description,
            }

        return {
@@ -388,6 +399,8 @@ def create_app(args):
            "auth_mode": "enabled",
            "core_version": core_version,
            "api_version": __api_version__,
            "webui_title": webui_title,
            "webui_description": webui_description,
        }

    @app.post("/login")
@@ -404,6 +417,8 @@ def create_app(args):
                "message": "Authentication is disabled. Using guest access.",
                "core_version": core_version,
                "api_version": __api_version__,
                "webui_title": webui_title,
                "webui_description": webui_description,
            }
        username = form_data.username
        if auth_handler.accounts.get(username) != form_data.password:
@@ -421,6 +436,8 @@ def create_app(args):
            "auth_mode": "enabled",
            "core_version": core_version,
            "api_version": __api_version__,
            "webui_title": webui_title,
            "webui_description": webui_description,
        }
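Judging from the /login handler above (username and password checked against `auth_handler.accounts`), the endpoint accepts a standard OAuth2 password form. A hedged example request; the form field names follow FastAPI's OAuth2PasswordRequestForm convention and the credentials are placeholders:

```python
import requests

# Hypothetical credentials that would have been configured via AUTH_ACCOUNTS
resp = requests.post(
    "http://localhost:9621/login",
    data={"username": "admin", "password": "secret123"},  # form-encoded, not JSON
)
print(resp.json())  # expected to include a token plus the core/api version fields added above
```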
    @app.get("/health", dependencies=[Depends(combined_auth)])
@@ -454,10 +471,12 @@ def create_app(args):
                "vector_storage": args.vector_storage,
                "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
            },
            "core_version": core_version,
            "api_version": __api_version__,
            "auth_mode": auth_mode,
            "pipeline_busy": pipeline_status.get("busy", False),
            "core_version": core_version,
            "api_version": __api_version__,
            "webui_title": webui_title,
            "webui_description": webui_description,
        }
    except Exception as e:
        logger.error(f"Error getting health status: {str(e)}")

@@ -490,7 +509,7 @@ def create_app(args):
def get_application(args=None):
    """Factory function for creating the FastAPI application"""
    if args is None:
        args = parse_args()
        args = global_args
    return create_app(args)

@@ -611,30 +630,31 @@ def main():

    # Configure logging before parsing args
    configure_logging()

    args = parse_args(is_uvicorn_mode=True)
    display_splash_screen(args)
    update_uvicorn_mode_config()
    display_splash_screen(global_args)

    # Create application instance directly instead of using factory function
    app = create_app(args)
    app = create_app(global_args)

    # Start Uvicorn in single process mode
    uvicorn_config = {
        "app": app,  # Pass application instance directly instead of string path
        "host": args.host,
        "port": args.port,
        "host": global_args.host,
        "port": global_args.port,
        "log_config": None,  # Disable default config
    }

    if args.ssl:
    if global_args.ssl:
        uvicorn_config.update(
            {
                "ssl_certfile": args.ssl_certfile,
                "ssl_keyfile": args.ssl_keyfile,
                "ssl_certfile": global_args.ssl_certfile,
                "ssl_keyfile": global_args.ssl_keyfile,
            }
        )

    print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
    print(
        f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
    )
    uvicorn.run(**uvicorn_config)
lightrag/api/routers/document_routes.py
@@ -10,16 +10,14 @@ import traceback
import pipmaster as pm
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Literal
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from pydantic import BaseModel, Field, field_validator

from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.api.utils_api import (
    get_combined_auth_dependency,
    global_args,
)
from lightrag.api.utils_api import get_combined_auth_dependency
from ..config import global_args

router = APIRouter(
    prefix="/documents",
@@ -30,7 +28,37 @@ router = APIRouter(
temp_prefix = "__tmp__"


class ScanResponse(BaseModel):
    """Response model for document scanning operation

    Attributes:
        status: Status of the scanning operation
        message: Optional message with additional details
    """

    status: Literal["scanning_started"] = Field(
        description="Status of the scanning operation"
    )
    message: Optional[str] = Field(
        default=None, description="Additional details about the scanning operation"
    )

    class Config:
        json_schema_extra = {
            "example": {
                "status": "scanning_started",
                "message": "Scanning process has been initiated in the background",
            }
        }


class InsertTextRequest(BaseModel):
    """Request model for inserting a single text document

    Attributes:
        text: The text content to be inserted into the RAG system
    """

    text: str = Field(
        min_length=1,
        description="The text to insert",
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
    def strip_after(cls, text: str) -> str:
        return text.strip()

    class Config:
        json_schema_extra = {
            "example": {
                "text": "This is a sample text to be inserted into the RAG system."
            }
        }


class InsertTextsRequest(BaseModel):
    """Request model for inserting multiple text documents

    Attributes:
        texts: List of text contents to be inserted into the RAG system
    """

    texts: list[str] = Field(
        min_length=1,
        description="The texts to insert",
@@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel):
    def strip_after(cls, texts: list[str]) -> list[str]:
        return [text.strip() for text in texts]

    class Config:
        json_schema_extra = {
            "example": {
                "texts": [
                    "This is the first text to be inserted.",
                    "This is the second text to be inserted.",
                ]
            }
        }


class InsertResponse(BaseModel):
    status: str = Field(description="Status of the operation")
    """Response model for document insertion operations

    Attributes:
        status: Status of the operation (success, duplicated, partial_success, failure)
        message: Detailed message describing the operation result
    """

    status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
        description="Status of the operation"
    )
    message: str = Field(description="Message describing the operation result")

    class Config:
        json_schema_extra = {
            "example": {
                "status": "success",
                "message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
            }
        }


class ClearDocumentsResponse(BaseModel):
    """Response model for document clearing operation

    Attributes:
        status: Status of the clear operation
        message: Detailed message describing the operation result
    """

    status: Literal["success", "partial_success", "busy", "fail"] = Field(
        description="Status of the clear operation"
    )
    message: str = Field(description="Message describing the operation result")

    class Config:
        json_schema_extra = {
            "example": {
                "status": "success",
                "message": "All documents cleared successfully. Deleted 15 files.",
            }
        }


class ClearCacheRequest(BaseModel):
    """Request model for clearing cache

    Attributes:
        modes: Optional list of cache modes to clear
    """

    modes: Optional[
        List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
    ] = Field(
        default=None,
        description="Modes of cache to clear. If None, clears all cache.",
    )

    class Config:
        json_schema_extra = {"example": {"modes": ["default", "naive"]}}


class ClearCacheResponse(BaseModel):
    """Response model for cache clearing operation

    Attributes:
        status: Status of the clear operation
        message: Detailed message describing the operation result
    """

    status: Literal["success", "fail"] = Field(
        description="Status of the clear operation"
    )
    message: str = Field(description="Message describing the operation result")

    class Config:
        json_schema_extra = {
            "example": {
                "status": "success",
                "message": "Successfully cleared cache for modes: ['default', 'naive']",
            }
        }


"""Response model for document status

Attributes:
    id: Document identifier
    content_summary: Summary of document content
    content_length: Length of document content
    status: Current processing status
    created_at: Creation timestamp (ISO format string)
    updated_at: Last update timestamp (ISO format string)
    chunks_count: Number of chunks (optional)
    error: Error message if any (optional)
    metadata: Additional metadata (optional)
    file_path: Path to the document file
"""


class DocStatusResponse(BaseModel):
    @staticmethod
@@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel):
            return dt
        return dt.isoformat()

    """Response model for document status
    id: str = Field(description="Document identifier")
    content_summary: str = Field(description="Summary of document content")
    content_length: int = Field(description="Length of document content in characters")
    status: DocStatus = Field(description="Current processing status")
    created_at: str = Field(description="Creation timestamp (ISO format string)")
    updated_at: str = Field(description="Last update timestamp (ISO format string)")
    chunks_count: Optional[int] = Field(
        default=None, description="Number of chunks the document was split into"
    )
    error: Optional[str] = Field(
        default=None, description="Error message if processing failed"
    )
    metadata: Optional[dict[str, Any]] = Field(
        default=None, description="Additional metadata about the document"
    )
    file_path: str = Field(description="Path to the document file")

    Attributes:
        id: Document identifier
        content_summary: Summary of document content
        content_length: Length of document content
        status: Current processing status
        created_at: Creation timestamp (ISO format string)
        updated_at: Last update timestamp (ISO format string)
        chunks_count: Number of chunks (optional)
        error: Error message if any (optional)
        metadata: Additional metadata (optional)
    """

    id: str
    content_summary: str
    content_length: int
    status: DocStatus
    created_at: str
    updated_at: str
    chunks_count: Optional[int] = None
    error: Optional[str] = None
    metadata: Optional[dict[str, Any]] = None
    file_path: str
    class Config:
        json_schema_extra = {
            "example": {
                "id": "doc_123456",
                "content_summary": "Research paper on machine learning",
                "content_length": 15240,
                "status": "PROCESSED",
                "created_at": "2025-03-31T12:34:56",
                "updated_at": "2025-03-31T12:35:30",
                "chunks_count": 12,
                "error": None,
                "metadata": {"author": "John Doe", "year": 2025},
                "file_path": "research_paper.pdf",
            }
        }


class DocsStatusesResponse(BaseModel):
    statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
    """Response model for document statuses

    Attributes:
        statuses: Dictionary mapping document status to lists of document status responses
    """

    statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
        default_factory=dict,
        description="Dictionary mapping document status to lists of document status responses",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "statuses": {
                    "PENDING": [
                        {
                            "id": "doc_123",
                            "content_summary": "Pending document",
                            "content_length": 5000,
                            "status": "PENDING",
                            "created_at": "2025-03-31T10:00:00",
                            "updated_at": "2025-03-31T10:00:00",
                            "file_path": "pending_doc.pdf",
                        }
                    ],
                    "PROCESSED": [
                        {
                            "id": "doc_456",
                            "content_summary": "Processed document",
                            "content_length": 8000,
                            "status": "PROCESSED",
                            "created_at": "2025-03-31T09:00:00",
                            "updated_at": "2025-03-31T09:05:00",
                            "chunks_count": 8,
                            "file_path": "processed_doc.pdf",
                        }
                    ],
                }
            }
        }


class PipelineStatusResponse(BaseModel):
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                )
                return False
            case ".pdf":
                if global_args["main_args"].document_loading_engine == "DOCLING":
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                    for page in reader.pages:
                        content += page.extract_text() + "\n"
            case ".docx":
                if global_args["main_args"].document_loading_engine == "DOCLING":
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                        [paragraph.text for paragraph in doc.paragraphs]
                    )
            case ".pptx":
                if global_args["main_args"].document_loading_engine == "DOCLING":
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                        if hasattr(shape, "text"):
                            content += shape.text + "\n"
            case ".xlsx":
                if global_args["main_args"].document_loading_engine == "DOCLING":
                if global_args.document_loading_engine == "DOCLING":
                    if not pm.is_installed("docling"):  # type: ignore
                        pm.install("docling")
                    from docling.document_converter import DocumentConverter  # type: ignore
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
    await rag.apipeline_process_enqueue_documents()


# TODO: deprecate after /insert_file is removed
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
    """Save the uploaded file to a temporary location

@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
    if not new_files:
        return

    # Get MAX_PARALLEL_INSERT from global_args["main_args"]
    max_parallel = global_args["main_args"].max_parallel_insert
    # Get MAX_PARALLEL_INSERT from global_args
    max_parallel = global_args.max_parallel_insert
    # Calculate batch size as 2 * MAX_PARALLEL_INSERT
    batch_size = 2 * max_parallel

@@ -509,7 +704,9 @@ def create_document_routes(
    # Create combined auth dependency for document routes
    combined_auth = get_combined_auth_dependency(api_key)

    @router.post("/scan", dependencies=[Depends(combined_auth)])
    @router.post(
        "/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
    )
    async def scan_for_new_documents(background_tasks: BackgroundTasks):
        """
        Trigger the scanning process for new documents.
@@ -519,13 +716,18 @@ def create_document_routes(
        that fact.

        Returns:
            dict: A dictionary containing the scanning status
            ScanResponse: A response object containing the scanning status
        """
        # Start the scanning process in the background
        background_tasks.add_task(run_scanning_process, rag, doc_manager)
        return {"status": "scanning_started"}
        return ScanResponse(
            status="scanning_started",
            message="Scanning process has been initiated in the background",
        )
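A hedged usage sketch for the reworked /scan endpoint; the base URL uses the default host and port from the config above, and authentication headers are omitted:

```python
import requests

# POST /documents/scan kicks off a background scan and returns immediately
resp = requests.post("http://localhost:9621/documents/scan")
print(resp.json())
# {"status": "scanning_started",
#  "message": "Scanning process has been initiated in the background"}
```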
    @router.post("/upload", dependencies=[Depends(combined_auth)])
    @router.post(
        "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
    )
    async def upload_to_input_dir(
        background_tasks: BackgroundTasks, file: UploadFile = File(...)
    ):
@@ -645,6 +847,7 @@ def create_document_routes(
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    # TODO: deprecated, use /upload instead
    @router.post(
        "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
    )
@@ -688,6 +891,7 @@ def create_document_routes(
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    # TODO: deprecated, use /upload instead
    @router.post(
        "/file_batch",
        response_model=InsertResponse,
@@ -752,32 +956,186 @@ def create_document_routes(
            raise HTTPException(status_code=500, detail=str(e))

    @router.delete(
        "", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
        "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
    )
    async def clear_documents():
        """
        Clear all documents from the RAG system.

        This endpoint deletes all text chunks, entities vector database, and relationships
        vector database, effectively clearing all documents from the RAG system.
        This endpoint deletes all documents, entities, relationships, and files from the system.
        It uses the storage drop methods to properly clean up all data and removes all files
        from the input directory.

        Returns:
            InsertResponse: A response object containing the status and message.
            ClearDocumentsResponse: A response object containing the status and message.
                - status="success": All documents and files were successfully cleared.
                - status="partial_success": Document clear job exited with some errors.
                - status="busy": Operation could not be completed because the pipeline is busy.
                - status="fail": All storage drop operations failed, with message
                - message: Detailed information about the operation results, including counts
                  of deleted files and any errors encountered.

        Raises:
            HTTPException: If an error occurs during the clearing process (500).
            HTTPException: Raised when a serious error occurs during the clearing process,
                with status code 500 and error details in the detail field.
        """
        try:
            rag.text_chunks = []
            rag.entities_vdb = None
            rag.relationships_vdb = None
            return InsertResponse(
                status="success", message="All documents cleared successfully"
        from lightrag.kg.shared_storage import (
            get_namespace_data,
            get_pipeline_status_lock,
        )

        # Get pipeline status and lock
        pipeline_status = await get_namespace_data("pipeline_status")
        pipeline_status_lock = get_pipeline_status_lock()

        # Check and set status with lock
        async with pipeline_status_lock:
            if pipeline_status.get("busy", False):
                return ClearDocumentsResponse(
                    status="busy",
                    message="Cannot clear documents while pipeline is busy",
                )
            # Set busy to true
            pipeline_status.update(
                {
                    "busy": True,
                    "job_name": "Clearing Documents",
                    "job_start": datetime.now().isoformat(),
                    "docs": 0,
                    "batchs": 0,
                    "cur_batch": 0,
                    "request_pending": False,  # Clear any previous request
                    "latest_message": "Starting document clearing process",
                }
            )
            # Cleaning history_messages without breaking it as a shared list object
            del pipeline_status["history_messages"][:]
            pipeline_status["history_messages"].append(
                "Starting document clearing process"
            )

        try:
            # Use drop method to clear all data
            drop_tasks = []
            storages = [
                rag.text_chunks,
                rag.full_docs,
                rag.entities_vdb,
                rag.relationships_vdb,
                rag.chunks_vdb,
                rag.chunk_entity_relation_graph,
                rag.doc_status,
            ]

            # Log storage drop start
            if "history_messages" in pipeline_status:
                pipeline_status["history_messages"].append(
                    "Starting to drop storage components"
                )

            for storage in storages:
                if storage is not None:
                    drop_tasks.append(storage.drop())

            # Wait for all drop tasks to complete
            drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)

            # Check for errors and log results
            errors = []
            storage_success_count = 0
            storage_error_count = 0

            for i, result in enumerate(drop_results):
                storage_name = storages[i].__class__.__name__
                if isinstance(result, Exception):
                    error_msg = f"Error dropping {storage_name}: {str(result)}"
                    errors.append(error_msg)
                    logger.error(error_msg)
                    storage_error_count += 1
                else:
                    logger.info(f"Successfully dropped {storage_name}")
                    storage_success_count += 1

            # Log storage drop results
            if "history_messages" in pipeline_status:
                if storage_error_count > 0:
                    pipeline_status["history_messages"].append(
                        f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
                    )
                else:
                    pipeline_status["history_messages"].append(
                        f"Successfully dropped all {storage_success_count} storage components"
                    )

            # If all storage operations failed, return error status and don't proceed with file deletion
            if storage_success_count == 0 and storage_error_count > 0:
                error_message = "All storage drop operations failed. Aborting document clearing process."
                logger.error(error_message)
                if "history_messages" in pipeline_status:
                    pipeline_status["history_messages"].append(error_message)
                return ClearDocumentsResponse(status="fail", message=error_message)

            # Log file deletion start
            if "history_messages" in pipeline_status:
                pipeline_status["history_messages"].append(
                    "Starting to delete files in input directory"
                )

            # Delete all files in input_dir
            deleted_files_count = 0
            file_errors_count = 0

            for file_path in doc_manager.input_dir.glob("**/*"):
                if file_path.is_file():
                    try:
                        file_path.unlink()
                        deleted_files_count += 1
                    except Exception as e:
                        logger.error(f"Error deleting file {file_path}: {str(e)}")
                        file_errors_count += 1

            # Log file deletion results
            if "history_messages" in pipeline_status:
                if file_errors_count > 0:
                    pipeline_status["history_messages"].append(
                        f"Deleted {deleted_files_count} files with {file_errors_count} errors"
                    )
                    errors.append(f"Failed to delete {file_errors_count} files")
                else:
                    pipeline_status["history_messages"].append(
                        f"Successfully deleted {deleted_files_count} files"
                    )

            # Prepare final result message
            final_message = ""
            if errors:
                final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
                status = "partial_success"
            else:
                final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
                status = "success"

            # Log final result
            if "history_messages" in pipeline_status:
                pipeline_status["history_messages"].append(final_message)

            # Return response based on results
            return ClearDocumentsResponse(status=status, message=final_message)
        except Exception as e:
            logger.error(f"Error DELETE /documents: {str(e)}")
            error_msg = f"Error clearing documents: {str(e)}"
            logger.error(error_msg)
            logger.error(traceback.format_exc())
            if "history_messages" in pipeline_status:
                pipeline_status["history_messages"].append(error_msg)
            raise HTTPException(status_code=500, detail=str(e))
        finally:
            # Reset busy status after completion
            async with pipeline_status_lock:
                pipeline_status["busy"] = False
                completion_msg = "Document clearing process completed"
                pipeline_status["latest_message"] = completion_msg
                if "history_messages" in pipeline_status:
                    pipeline_status["history_messages"].append(completion_msg)

    @router.get(
        "/pipeline_status",
@@ -850,7 +1208,9 @@ def create_document_routes(
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    @router.get("", dependencies=[Depends(combined_auth)])
    @router.get(
        "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
    )
    async def documents() -> DocsStatusesResponse:
        """
        Get the status of all documents in the system.
@@ -908,4 +1268,57 @@ def create_document_routes(
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    @router.post(
        "/clear_cache",
        response_model=ClearCacheResponse,
        dependencies=[Depends(combined_auth)],
    )
    async def clear_cache(request: ClearCacheRequest):
        """
        Clear cache data from the LLM response cache storage.

        This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
        Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
        - "default" represents extraction cache.
        - Other modes correspond to different query modes.

        Args:
            request (ClearCacheRequest): The request body containing optional modes to clear.

        Returns:
            ClearCacheResponse: A response object containing the status and message.

        Raises:
            HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
        """
        try:
            # Validate modes if provided
            valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
            if request.modes and not all(mode in valid_modes for mode in request.modes):
                invalid_modes = [
                    mode for mode in request.modes if mode not in valid_modes
                ]
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
                )

            # Call the aclear_cache method
            await rag.aclear_cache(request.modes)

            # Prepare success message
            if request.modes:
                message = f"Successfully cleared cache for modes: {request.modes}"
            else:
                message = "Successfully cleared all cache"

            return ClearCacheResponse(status="success", message=message)
        except HTTPException:
            # Re-raise HTTP exceptions
            raise
        except Exception as e:
            logger.error(f"Error clearing cache: {str(e)}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    return router
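A hedged example call for the new /clear_cache route; the path and payload follow the router prefix and request model above, and the host/port are the documented defaults:

```python
import requests

# Clear only the extraction cache ("default") and the naive-query cache
resp = requests.post(
    "http://localhost:9621/documents/clear_cache",
    json={"modes": ["default", "naive"]},
)
print(resp.json())
# {"status": "success",
#  "message": "Successfully cleared cache for modes: ['default', 'naive']"}

# Omitting "modes" (or sending null) clears all cache
requests.post("http://localhost:9621/documents/clear_cache", json={})
```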
lightrag/api/routers/graph_routes.py
@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
"""

from typing import Optional
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query

from ..utils_api import get_combined_auth_dependency

@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):

    @router.get("/graphs", dependencies=[Depends(combined_auth)])
    async def get_knowledge_graph(
        label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
        label: str = Query(..., description="Label to get knowledge graph for"),
        max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
        max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
    ):
        """
        Retrieve a connected subgraph of nodes where the label includes the specified label.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
        1. min_degree does not affect nodes directly connected to the matching nodes
        2. Label matching nodes take precedence
        3. Followed by nodes directly connected to the matching nodes
        4. Finally, the degree of the nodes
        Maximum number of nodes is limited to env MAX_GRAPH_NODES (default: 1000).
        When reducing the number of nodes:
        1. Hops (path length) to the starting node take precedence
        2. Followed by the degree of the nodes

        Args:
            label (str): Label to get knowledge graph for
            max_depth (int, optional): Maximum depth of graph. Defaults to 3.
            inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
            min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
            label (str): Label of the starting node
            max_depth (int, optional): Maximum depth of the subgraph. Defaults to 3.
            max_nodes (int, optional): Maximum nodes to return. Defaults to 1000.

        Returns:
            Dict[str, List[str]]: Knowledge graph for label
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
        return await rag.get_knowledge_graph(
            node_label=label,
            max_depth=max_depth,
            inclusive=inclusive,
            min_degree=min_degree,
            max_nodes=max_nodes,
        )

    return router
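A hedged sketch of calling the updated endpoint with its new query parameters, assuming the graph router is mounted at the application root and that authentication is disabled:

```python
import requests

# Fetch a subgraph rooted at the node labeled "Machine Learning";
# the query parameter names match the Query(...) declarations above.
resp = requests.get(
    "http://localhost:9621/graphs",
    params={"label": "Machine Learning", "max_depth": 3, "max_nodes": 500},
)
print(resp.json())
```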
lightrag/api/run_with_gunicorn.py
@@ -7,14 +7,9 @@ import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file
from lightrag.api.utils_api import display_splash_screen, check_env_file
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
from dotenv import load_dotenv

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
from .config import global_args


def check_and_install_dependencies():
@@ -59,20 +54,17 @@ def main():
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # kill command

    # Parse all arguments using parse_args
    args = parse_args(is_uvicorn_mode=False)

    # Display startup information
    display_splash_screen(args)
    display_splash_screen(global_args)

    print("🚀 Starting LightRAG with Gunicorn")
    print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
    print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
    print("🔍 Preloading app: Enabled")
    print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
    print("\n\n" + "=" * 80)
    print("MAIN PROCESS INITIALIZATION")
    print(f"Process ID: {os.getpid()}")
    print(f"Workers setting: {args.workers}")
    print(f"Workers setting: {global_args.workers}")
    print("=" * 80 + "\n")

    # Import Gunicorn's StandaloneApplication
@@ -128,31 +120,43 @@ def main():

    # Set configuration variables in gunicorn_config, prioritizing command line arguments
    gunicorn_config.workers = (
        args.workers if args.workers else int(os.getenv("WORKERS", 1))
        global_args.workers
        if global_args.workers
        else int(os.getenv("WORKERS", 1))
    )

    # Bind configuration prioritizes command line arguments
    host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
    port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
    host = (
        global_args.host
        if global_args.host != "0.0.0.0"
        else os.getenv("HOST", "0.0.0.0")
    )
    port = (
        global_args.port
        if global_args.port != 9621
        else int(os.getenv("PORT", 9621))
    )
    gunicorn_config.bind = f"{host}:{port}"

    # Log level configuration prioritizes command line arguments
    gunicorn_config.loglevel = (
        args.log_level.lower()
        if args.log_level
        global_args.log_level.lower()
        if global_args.log_level
        else os.getenv("LOG_LEVEL", "info")
    )

    # Timeout configuration prioritizes command line arguments
    gunicorn_config.timeout = (
        args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2))
        global_args.timeout
        if global_args.timeout * 2
        else int(os.getenv("TIMEOUT", 150 * 2))
    )

    # Keepalive configuration
    gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))

    # SSL configuration prioritizes command line arguments
    if args.ssl or os.getenv("SSL", "").lower() in (
    if global_args.ssl or os.getenv("SSL", "").lower() in (
        "true",
        "1",
        "yes",
@@ -160,12 +164,14 @@ def main():
        "on",
    ):
        gunicorn_config.certfile = (
            args.ssl_certfile
            if args.ssl_certfile
            global_args.ssl_certfile
            if global_args.ssl_certfile
            else os.getenv("SSL_CERTFILE")
        )
        gunicorn_config.keyfile = (
            args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
            global_args.ssl_keyfile
            if global_args.ssl_keyfile
            else os.getenv("SSL_KEYFILE")
        )

    # Set configuration options from the module
@@ -190,13 +196,13 @@ def main():
        # Import the application
        from lightrag.api.lightrag_server import get_application

        return get_application(args)
        return get_application(global_args)

    # Create the application
    app = GunicornApp("")

    # Force workers to be an integer and greater than 1 for multi-process mode
    workers_count = int(args.workers)
    workers_count = int(global_args.workers)
    if workers_count > 1:
        # Set a flag to indicate we're in the main process
        os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
@@ -7,15 +7,13 @@ import argparse
|
||||
from typing import Optional, List, Tuple
|
||||
import sys
|
||||
from ascii_colors import ASCIIColors
|
||||
import logging
|
||||
from lightrag.api import __api_version__ as api_version
|
||||
from lightrag import __version__ as core_version
|
||||
from fastapi import HTTPException, Security, Request, status
|
||||
from dotenv import load_dotenv
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from .auth import auth_handler
|
||||
from ..prompt import PROMPTS
|
||||
from .config import ollama_server_infos, global_args
|
||||
|
||||
|
||||
def check_env_file():
|
||||
@@ -36,16 +34,8 @@ def check_env_file():
|
||||
return True
|
||||
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# allows to use different .env file for each lightrag instance
|
||||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
global_args = {"main_args": None}
|
||||
|
||||
# Get whitelist paths from environment variable, only once during initialization
|
||||
default_whitelist = "/health,/api/*"
|
||||
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
|
||||
# Get whitelist paths from global_args, only once during initialization
|
||||
whitelist_paths = global_args.whitelist_paths.split(",")
|
||||
|
||||
# Pre-compile path matching patterns
|
||||
whitelist_patterns: List[Tuple[str, bool]] = []
|
||||
@@ -63,19 +53,6 @@ for path in whitelist_paths:
|
||||
auth_configured = bool(auth_handler.accounts)
|
||||
|
||||
|
||||
class OllamaServerInfos:
|
||||
# Constants for emulated Ollama model information
|
||||
LIGHTRAG_NAME = "lightrag"
|
||||
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
||||
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
||||
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
||||
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
||||
LIGHTRAG_DIGEST = "sha256:lightrag"
|
||||
|
||||
|
||||
ollama_server_infos = OllamaServerInfos()
|
||||
|
||||
|
||||
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
"""
|
||||
Create a combined authentication dependency that implements authentication logic
|
||||
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
return combined_dependency
|
||||
|
||||
|
||||
class DefaultRAGStorageConfig:
|
||||
KV_STORAGE = "JsonKVStorage"
|
||||
VECTOR_STORAGE = "NanoVectorDBStorage"
|
||||
GRAPH_STORAGE = "NetworkXStorage"
|
||||
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
||||
|
||||
|
||||
def get_default_host(binding_type: str) -> str:
|
||||
default_hosts = {
|
||||
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
||||
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
|
||||
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
|
||||
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
|
||||
}
|
||||
return default_hosts.get(
|
||||
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
|
||||
) # fallback to ollama if unknown
|
||||
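For example, `get_default_host("lollms")` yields `LLM_BINDING_HOST` if that variable is set and `http://localhost:9600` otherwise, while an unrecognized binding type falls back to the Ollama default.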
|
||||
|
||||
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
|
||||
"""
|
||||
Get value from environment variable with type conversion
|
||||
|
||||
Args:
|
||||
env_key (str): Environment variable key
|
||||
default (any): Default value if env variable is not set
|
||||
value_type (type): Type to convert the value to
|
||||
|
||||
Returns:
|
||||
any: Converted value from environment or default
|
||||
"""
|
||||
value = os.getenv(env_key)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
if value_type is bool:
|
||||
return value.lower() in ("true", "1", "yes", "t", "on")
|
||||
try:
|
||||
return value_type(value)
|
||||
except ValueError:
|
||||
return default
|
||||
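For reference, a minimal usage sketch of `get_env_value`; the `DEMO_*` variable names are hypothetical, chosen only for illustration:

```python
# Hypothetical usage; the DEMO_* variable names are examples only.
import os

os.environ["DEMO_WORKERS"] = "4"
os.environ["DEMO_SSL"] = "on"

assert get_env_value("DEMO_WORKERS", 1, int) == 4      # str -> int conversion
assert get_env_value("DEMO_SSL", False, bool) is True  # truthy strings: true/1/yes/t/on
assert get_env_value("DEMO_MISSING", "n/a") == "n/a"   # unset -> default returned as-is
```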
|
||||
|
||||
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
||||
"""
|
||||
Parse command line arguments with environment variable fallback
|
||||
|
||||
Args:
|
||||
is_uvicorn_mode: Whether running under uvicorn mode
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LightRAG FastAPI Server with separate working and input directories"
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default=get_env_value("HOST", "0.0.0.0"),
|
||||
help="Server host (default: from env or 0.0.0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=get_env_value("PORT", 9621, int),
|
||||
help="Server port (default: from env or 9621)",
|
||||
)
|
||||
|
||||
# Directory configuration
|
||||
parser.add_argument(
|
||||
"--working-dir",
|
||||
default=get_env_value("WORKING_DIR", "./rag_storage"),
|
||||
help="Working directory for RAG storage (default: from env or ./rag_storage)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-dir",
|
||||
default=get_env_value("INPUT_DIR", "./inputs"),
|
||||
help="Directory containing input documents (default: from env or ./inputs)",
|
||||
)
|
||||
|
||||
def timeout_type(value):
|
||||
if value is None:
|
||||
return 150
|
||||
if value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
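In effect: an unset value falls back to 150 seconds, the literal string "None" disables the timeout, and anything else is converted to an integer. A quick illustration:

```python
# Behavior of timeout_type:
assert timeout_type(None) == 150     # variable unset: 150s default
assert timeout_type("None") is None  # literal "None": infinite timeout
assert timeout_type("30") == 30      # numeric string: int conversion
```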
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
default=get_env_value("TIMEOUT", None, timeout_type),
|
||||
type=timeout_type,
|
||||
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
||||
)
|
||||
|
||||
# RAG configuration
|
||||
parser.add_argument(
|
||||
"--max-async",
|
||||
type=int,
|
||||
default=get_env_value("MAX_ASYNC", 4, int),
|
||||
help="Maximum async operations (default: from env or 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=get_env_value("MAX_TOKENS", 32768, int),
|
||||
help="Maximum token size (default: from env or 32768)",
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default=get_env_value("LOG_LEVEL", "INFO"),
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Logging level (default: from env or INFO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=get_env_value("VERBOSE", False, bool),
|
||||
help="Enable verbose debug output(only valid for DEBUG log-level)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
type=str,
|
||||
default=get_env_value("LIGHTRAG_API_KEY", None),
|
||||
help="API key for authentication. This protects lightrag server against unauthorized access",
|
||||
)
|
||||
|
||||
# Optional https parameters
|
||||
parser.add_argument(
|
||||
"--ssl",
|
||||
action="store_true",
|
||||
default=get_env_value("SSL", False, bool),
|
||||
help="Enable HTTPS (default: from env or False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-certfile",
|
||||
default=get_env_value("SSL_CERTFILE", None),
|
||||
help="Path to SSL certificate file (required if --ssl is enabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-keyfile",
|
||||
default=get_env_value("SSL_KEYFILE", None),
|
||||
help="Path to SSL private key file (required if --ssl is enabled)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--history-turns",
|
||||
type=int,
|
||||
default=get_env_value("HISTORY_TURNS", 3, int),
|
||||
help="Number of conversation history turns to include (default: from env or 3)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=get_env_value("TOP_K", 60, int),
|
||||
help="Number of most similar results to return (default: from env or 60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cosine-threshold",
|
||||
type=float,
|
||||
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
||||
help="Cosine similarity threshold (default: from env or 0.4)",
|
||||
)
|
||||
|
||||
# Ollama model name
|
||||
parser.add_argument(
|
||||
"--simulated-model-name",
|
||||
type=str,
|
||||
default=get_env_value(
|
||||
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
||||
),
|
||||
help="Number of conversation history turns to include (default: from env or 3)",
|
||||
)
|
||||
|
||||
# Namespace
|
||||
parser.add_argument(
|
||||
"--namespace-prefix",
|
||||
type=str,
|
||||
default=get_env_value("NAMESPACE_PREFIX", ""),
|
||||
help="Prefix of the namespace",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--auto-scan-at-startup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable automatic scanning when the program starts",
|
||||
)
|
||||
|
||||
# Server workers configuration
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=get_env_value("WORKERS", 1, int),
|
||||
help="Number of worker processes (default: from env or 1)",
|
||||
)
|
||||
|
||||
# LLM and embedding bindings
|
||||
parser.add_argument(
|
||||
"--llm-binding",
|
||||
type=str,
|
||||
default=get_env_value("LLM_BINDING", "ollama"),
|
||||
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
|
||||
help="LLM binding type (default: from env or ollama)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-binding",
|
||||
type=str,
|
||||
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
||||
choices=["lollms", "ollama", "openai", "azure_openai"],
|
||||
help="Embedding binding type (default: from env or ollama)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
||||
if is_uvicorn_mode and args.workers > 1:
|
||||
original_workers = args.workers
|
||||
args.workers = 1
|
||||
# Log warning directly here
|
||||
logging.warning(
|
||||
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
|
||||
)
|
||||
|
||||
# convert relative path to absolute path
|
||||
args.working_dir = os.path.abspath(args.working_dir)
|
||||
args.input_dir = os.path.abspath(args.input_dir)
|
||||
|
||||
# Inject storage configuration from environment variables
|
||||
args.kv_storage = get_env_value(
|
||||
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
||||
)
|
||||
args.doc_status_storage = get_env_value(
|
||||
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
||||
)
|
||||
args.graph_storage = get_env_value(
|
||||
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
||||
)
|
||||
args.vector_storage = get_env_value(
|
||||
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
||||
)
|
||||
|
||||
# Get MAX_PARALLEL_INSERT from environment
|
||||
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
||||
|
||||
# Handle openai-ollama special case
|
||||
if args.llm_binding == "openai-ollama":
|
||||
args.llm_binding = "openai"
|
||||
args.embedding_binding = "ollama"
|
||||
|
||||
args.llm_binding_host = get_env_value(
|
||||
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
|
||||
)
|
||||
args.embedding_binding_host = get_env_value(
|
||||
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
|
||||
)
|
||||
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
|
||||
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
|
||||
|
||||
# Inject model configuration
|
||||
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
||||
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
||||
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
||||
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
|
||||
|
||||
# Inject chunk configuration
|
||||
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
||||
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
||||
|
||||
# Inject LLM cache configuration
|
||||
args.enable_llm_cache_for_extract = get_env_value(
|
||||
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
|
||||
)
|
||||
|
||||
# Inject LLM temperature configuration
|
||||
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
|
||||
|
||||
# Select Document loading tool (DOCLING, DEFAULT)
|
||||
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
||||
|
||||
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
||||
|
||||
global_args["main_args"] = args
|
||||
return args
|
||||
|
||||
|
||||
def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
"""
|
||||
Display a colorful splash screen showing LightRAG server configuration
|
||||
@@ -489,7 +173,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
# Banner
|
||||
ASCIIColors.cyan(f"""
|
||||
╔══════════════════════════════════════════════════════════════╗
|
||||
║ 🚀 LightRAG Server v{core_version}/{api_version} ║
|
||||
║ 🚀 LightRAG Server v{core_version}/{api_version} ║
|
||||
║ Fast, Lightweight RAG Server Implementation ║
|
||||
╚══════════════════════════════════════════════════════════════╝
|
||||
""")
|
||||
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.white(" ├─ Workers: ", end="")
|
||||
ASCIIColors.yellow(f"{args.workers}")
|
||||
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
||||
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
|
||||
ASCIIColors.yellow(f"{args.cors_origins}")
|
||||
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
||||
ASCIIColors.yellow(f"{args.ssl}")
|
||||
if args.ssl:
|
||||
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.verbose}")
|
||||
ASCIIColors.white(" ├─ History Turns: ", end="")
|
||||
ASCIIColors.yellow(f"{args.history_turns}")
|
||||
ASCIIColors.white(" └─ API Key: ", end="")
|
||||
ASCIIColors.white(" ├─ API Key: ", end="")
|
||||
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
||||
ASCIIColors.white(" └─ JWT Auth: ", end="")
|
||||
ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
|
||||
|
||||
# Directory Configuration
|
||||
ASCIIColors.magenta("\n📂 Directory Configuration:")
|
||||
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.embedding_dim}")
|
||||
|
||||
# RAG Configuration
|
||||
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
|
||||
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
||||
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
||||
ASCIIColors.yellow(f"{summary_language}")
|
||||
ASCIIColors.yellow(f"{args.summary_language}")
|
||||
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
||||
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
|
||||
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
protocol = "https" if args.ssl else "http"
|
||||
if args.host == "0.0.0.0":
|
||||
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
||||
ASCIIColors.white(" ├─ Local Access: ", end="")
|
||||
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
|
||||
ASCIIColors.white(" ├─ Remote Access: ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
|
||||
ASCIIColors.white(" ├─ API Documentation (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
|
||||
ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="")
|
||||
ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
|
||||
ASCIIColors.white(" └─ WebUI (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
|
||||
|
||||
ASCIIColors.yellow("\n📝 Note:")
|
||||
ASCIIColors.white(""" Since the server is running on 0.0.0.0:
|
||||
ASCIIColors.magenta("\n📝 Note:")
|
||||
ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
|
||||
- Use 'localhost' or '127.0.0.1' for local access
|
||||
- Use your machine's IP address for remote access
|
||||
- To find your IP address:
|
||||
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
else:
|
||||
base_url = f"{protocol}://{args.host}:{args.port}"
|
||||
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
||||
ASCIIColors.white(" ├─ Base URL: ", end="")
|
||||
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
||||
ASCIIColors.yellow(f"{base_url}")
|
||||
ASCIIColors.white(" ├─ API Documentation: ", end="")
|
||||
ASCIIColors.yellow(f"{base_url}/docs")
|
||||
ASCIIColors.white(" └─ Alternative Documentation: ", end="")
|
||||
ASCIIColors.yellow(f"{base_url}/redoc")
|
||||
|
||||
# Usage Examples
|
||||
ASCIIColors.magenta("\n📚 Quick Start Guide:")
|
||||
ASCIIColors.cyan("""
|
||||
1. Access the Swagger UI:
|
||||
Open your browser and navigate to the API documentation URL above
|
||||
|
||||
2. API Authentication:""")
|
||||
if args.key:
|
||||
ASCIIColors.cyan(""" Add the following header to your requests:
|
||||
X-API-Key: <your-api-key>
|
||||
""")
|
||||
else:
|
||||
ASCIIColors.cyan(" No authentication required\n")
|
||||
|
||||
ASCIIColors.cyan(""" 3. Basic Operations:
|
||||
- POST /upload_document: Upload new documents to RAG
|
||||
- POST /query: Query your document collection
|
||||
|
||||
4. Monitor the server:
|
||||
- Check server logs for detailed operation information
|
||||
- Use healthcheck endpoint: GET /health
|
||||
""")
|
||||
|
||||
# Security Notice
|
||||
if args.key:
|
||||
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
||||
ASCIIColors.white(""" API Key authentication is enabled.
|
||||
Make sure to include the X-API-Key header in all your requests.
|
||||
""")
|
||||
if args.auth_accounts:
|
||||
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
||||
ASCIIColors.white(""" JWT authentication is enabled.
|
||||
Make sure to login before making the request, and include the 'Authorization' in the header.
|
||||
""")
|
||||
|
||||
# Ensure splash output flush to system log
|
||||
sys.stdout.flush()
|
||||
|
1
lightrag/api/webui/assets/index-CD5HxTy1.css
generated
File diff suppressed because one or more lines are too long
1345
lightrag/api/webui/assets/index-Cma7xY0-.js
generated
Normal file
File diff suppressed because one or more lines are too long
1
lightrag/api/webui/assets/index-QU59h9JG.css
generated
Normal file
File diff suppressed because one or more lines are too long
1321
lightrag/api/webui/assets/index-raheqJeu.js
generated
File diff suppressed because one or more lines are too long
4
lightrag/api/webui/index.html
generated
@@ -8,8 +8,8 @@
|
||||
<link rel="icon" type="image/svg+xml" href="logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Lightrag</title>
|
||||
<script type="module" crossorigin src="/webui/assets/index-raheqJeu.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/webui/assets/index-CD5HxTy1.css">
|
||||
<script type="module" crossorigin src="/webui/assets/index-Cma7xY0-.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/webui/assets/index-QU59h9JG.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
131
lightrag/base.py
@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
|
||||
async def index_done_callback(self) -> None:
|
||||
"""Commit the storage operations after indexing"""
|
||||
|
||||
@abstractmethod
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all data from storage and clean up resources
|
||||
|
||||
This abstract method defines the contract for dropping all data from a storage implementation.
|
||||
Each storage type must implement this method to:
|
||||
1. Clear all data from memory and/or external storage
|
||||
2. Remove any associated storage files if applicable
|
||||
3. Reset the storage to its initial state
|
||||
4. Handle cleanup of any resources
|
||||
5. Notify other processes if necessary
|
||||
6. This action should persist the data to disk immediately.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message with the following format:
|
||||
{
|
||||
"status": str, # "success" or "error"
|
||||
"message": str # "data dropped" on success, error details on failure
|
||||
}
|
||||
|
||||
Implementation specific:
|
||||
- On success: return {"status": "success", "message": "data dropped"}
|
||||
- On failure: return {"status": "error", "message": "<error details>"}
|
||||
- If not supported: return {"status": "error", "message": "unsupported"}
|
||||
"""
|
||||
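As a reference point, a minimal in-memory implementation satisfying this contract might look like the following sketch (the `_data` dict is a hypothetical attribute, not part of the abstract class):

```python
# Minimal sketch of the drop() contract for a hypothetical in-memory storage.
async def drop(self) -> dict[str, str]:
    try:
        self._data.clear()                 # 1. clear all data from memory
        await self.index_done_callback()   # 6. persist the empty state immediately
        return {"status": "success", "message": "data dropped"}
    except Exception as e:
        return {"status": "error", "message": str(e)}
```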
|
||||
|
||||
@dataclass
|
||||
class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""Insert or update vectors in the storage."""
|
||||
"""Insert or update vectors in the storage.
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name."""
|
||||
"""Delete a single entity by its name.
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity."""
|
||||
"""Delete relations for a given entity.
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseKVStorage(StorageNameSpace, ABC):
|
||||
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""Upsert data"""
|
||||
"""Upsert data
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
True: if the cache was dropped successfully
|
||||
False: if the cache drop failed, or the cache mode is not supported
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get an edge by its source and target node ids."""
|
||||
"""Get node by its label identifier, return only node properties"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
"""Get all edges connected to a node."""
|
||||
"""Get edge properties between two nodes"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""Delete a node from the graph."""
|
||||
"""Delete a node from the graph.
|
||||
|
||||
Importance notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 3
|
||||
self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
|
||||
) -> KnowledgeGraph:
|
||||
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node, "*" means all nodes
|
||||
max_depth: Maximum depth of the subgraph, defaults to 3
|
||||
max_nodes: Maximum number of nodes to return, defaults to 1000 (BFS if possible)
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
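A hedged sketch of a call site for the extended signature (`graph_storage` is a hypothetical instance of a concrete implementation):

```python
# Fetch up to 500 nodes around all labels, two hops deep.
kg = await graph_storage.get_knowledge_graph("*", max_depth=2, max_nodes=500)
if kg.is_truncated:
    print(f"Subgraph truncated at 500 nodes ({len(kg.nodes)} returned)")
```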
|
||||
|
||||
class DocStatus(str, Enum):
|
||||
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
|
||||
) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all documents with a specific status"""
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Drop cache is not supported for Doc Status storage"""
|
||||
return False
|
||||
|
||||
|
||||
class StoragesStatus(str, Enum):
|
||||
"""Storages status"""
|
||||
|
@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"KV_STORAGE": {
|
||||
"implementations": [
|
||||
"JsonKVStorage",
|
||||
"MongoKVStorage",
|
||||
"RedisKVStorage",
|
||||
"TiDBKVStorage",
|
||||
"PGKVStorage",
|
||||
"OracleKVStorage",
|
||||
"MongoKVStorage",
|
||||
# "TiDBKVStorage",
|
||||
],
|
||||
"required_methods": ["get_by_id", "upsert"],
|
||||
},
|
||||
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"implementations": [
|
||||
"NetworkXStorage",
|
||||
"Neo4JStorage",
|
||||
"MongoGraphStorage",
|
||||
"TiDBGraphStorage",
|
||||
"AGEStorage",
|
||||
"GremlinStorage",
|
||||
"PGGraphStorage",
|
||||
"OracleGraphStorage",
|
||||
# "AGEStorage",
|
||||
# "MongoGraphStorage",
|
||||
# "TiDBGraphStorage",
|
||||
# "GremlinStorage",
|
||||
],
|
||||
"required_methods": ["upsert_node", "upsert_edge"],
|
||||
},
|
||||
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"NanoVectorDBStorage",
|
||||
"MilvusVectorDBStorage",
|
||||
"ChromaVectorDBStorage",
|
||||
"TiDBVectorDBStorage",
|
||||
"PGVectorStorage",
|
||||
"FaissVectorDBStorage",
|
||||
"QdrantVectorDBStorage",
|
||||
"OracleVectorDBStorage",
|
||||
"MongoVectorDBStorage",
|
||||
# "TiDBVectorDBStorage",
|
||||
],
|
||||
"required_methods": ["query", "upsert"],
|
||||
},
|
||||
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"implementations": [
|
||||
"JsonDocStatusStorage",
|
||||
"PGDocStatusStorage",
|
||||
"PGDocStatusStorage",
|
||||
"MongoDocStatusStorage",
|
||||
],
|
||||
"required_methods": ["get_docs_by_status"],
|
||||
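A sketch of how these lists can be used to validate a configured storage class (an assumption about the calling code, not a verbatim excerpt):

```python
def verify_storage_implementation(storage_type: str, name: str) -> None:
    # Fail fast if the configured class is not a known implementation.
    valid = STORAGE_IMPLEMENTATIONS[storage_type]["implementations"]
    if name not in valid:
        raise ValueError(f"Unknown {storage_type} implementation: {name} (valid: {valid})")

verify_storage_implementation("KV_STORAGE", "JsonKVStorage")  # passes silently
```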
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
||||
"JsonKVStorage": [],
|
||||
"MongoKVStorage": [],
|
||||
"RedisKVStorage": ["REDIS_URI"],
|
||||
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
# "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"OracleKVStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
# Graph Storage Implementations
|
||||
"NetworkXStorage": [],
|
||||
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
||||
"MongoGraphStorage": [],
|
||||
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"AGEStorage": [
|
||||
"AGE_POSTGRES_DB",
|
||||
"AGE_POSTGRES_USER",
|
||||
"AGE_POSTGRES_PASSWORD",
|
||||
],
|
||||
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
|
||||
# "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
|
||||
"PGGraphStorage": [
|
||||
"POSTGRES_USER",
|
||||
"POSTGRES_PASSWORD",
|
||||
"POSTGRES_DATABASE",
|
||||
],
|
||||
"OracleGraphStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
# Vector Storage Implementations
|
||||
"NanoVectorDBStorage": [],
|
||||
"MilvusVectorDBStorage": [],
|
||||
"ChromaVectorDBStorage": [],
|
||||
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
# "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"FaissVectorDBStorage": [],
|
||||
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
|
||||
"OracleVectorDBStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
"MongoVectorDBStorage": [],
|
||||
# Document Status Storage Implementations
|
||||
"JsonDocStatusStorage": [],
|
||||
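Similarly, a hedged sketch of consuming `STORAGE_ENV_REQUIREMENTS` to report missing configuration (an illustrative helper, not part of the diff):

```python
import os

def missing_env_vars(storage_name: str) -> list[str]:
    # Return required variables that are unset for the given implementation.
    required = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    return [var for var in required if not os.getenv(var)]

# e.g. without Redis configured: missing_env_vars("RedisKVStorage") == ["REDIS_URI"]
```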
@@ -112,9 +90,6 @@ STORAGES = {
|
||||
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
|
||||
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
|
||||
"Neo4JStorage": ".kg.neo4j_impl",
|
||||
"OracleKVStorage": ".kg.oracle_impl",
|
||||
"OracleGraphStorage": ".kg.oracle_impl",
|
||||
"OracleVectorDBStorage": ".kg.oracle_impl",
|
||||
"MilvusVectorDBStorage": ".kg.milvus_impl",
|
||||
"MongoKVStorage": ".kg.mongo_impl",
|
||||
"MongoDocStatusStorage": ".kg.mongo_impl",
|
||||
@@ -122,14 +97,14 @@ STORAGES = {
|
||||
"MongoVectorDBStorage": ".kg.mongo_impl",
|
||||
"RedisKVStorage": ".kg.redis_impl",
|
||||
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
||||
"TiDBKVStorage": ".kg.tidb_impl",
|
||||
"TiDBVectorDBStorage": ".kg.tidb_impl",
|
||||
"TiDBGraphStorage": ".kg.tidb_impl",
|
||||
# "TiDBKVStorage": ".kg.tidb_impl",
|
||||
# "TiDBVectorDBStorage": ".kg.tidb_impl",
|
||||
# "TiDBGraphStorage": ".kg.tidb_impl",
|
||||
"PGKVStorage": ".kg.postgres_impl",
|
||||
"PGVectorStorage": ".kg.postgres_impl",
|
||||
"AGEStorage": ".kg.age_impl",
|
||||
"PGGraphStorage": ".kg.postgres_impl",
|
||||
"GremlinStorage": ".kg.gremlin_impl",
|
||||
# "GremlinStorage": ".kg.gremlin_impl",
|
||||
"PGDocStatusStorage": ".kg.postgres_impl",
|
||||
"FaissVectorDBStorage": ".kg.faiss_impl",
|
||||
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
||||
|
@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
import psycopg
|
||||
from psycopg.rows import namedtuple_row
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
import psycopg # type: ignore
|
||||
from psycopg.rows import namedtuple_row # type: ignore
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
|
||||
|
||||
|
||||
class AGEQueryException(Exception):
|
||||
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
|
||||
async def index_done_callback(self) -> None:
|
||||
# AGE handles persistence automatically
|
||||
pass
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all nodes and relationships in the graph.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
query = """
|
||||
MATCH (n)
|
||||
DETACH DELETE n
|
||||
"""
|
||||
await self._query(query)
|
||||
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
|
||||
return {"status": "success", "message": "graph data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping graph {self.graph_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
import numpy as np
|
||||
@@ -10,8 +11,8 @@ import pipmaster as pm
|
||||
if not pm.is_installed("chromadb"):
|
||||
pm.install("chromadb")
|
||||
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.config import Settings
|
||||
from chromadb import HttpClient, PersistentClient # type: ignore
|
||||
from chromadb.config import Settings # type: ignore
|
||||
|
||||
|
||||
@final
|
||||
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will delete all documents from the ChromaDB collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
# Get all IDs in the collection
|
||||
result = self._collection.get(include=[])
|
||||
if result and result["ids"] and len(result["ids"]) > 0:
|
||||
# Delete all documents
|
||||
self._collection.delete(ids=result["ids"])
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
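Design note: `get(include=[])` asks Chroma for IDs only, omitting embeddings, documents, and metadata, so enumerating the collection before the bulk delete stays cheap.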
|
@@ -11,16 +11,20 @@ import pipmaster as pm
|
||||
from lightrag.utils import logger, compute_mdhash_id
|
||||
from lightrag.base import BaseVectorStorage
|
||||
|
||||
if not pm.is_installed("faiss"):
|
||||
pm.install("faiss")
|
||||
|
||||
import faiss # type: ignore
|
||||
from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
)
|
||||
|
||||
import faiss # type: ignore
|
||||
|
||||
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
|
||||
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
|
||||
|
||||
if not pm.is_installed(FAISS_PACKAGE):
|
||||
pm.install(FAISS_PACKAGE)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
async def delete(self, ids: list[str]):
|
||||
"""
|
||||
Delete vectors for the provided custom IDs.
|
||||
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
||||
to_remove = []
|
||||
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
|
||||
await self.delete([entity_id])
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""
|
||||
Delete relations for a given entity by scanning metadata.
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
logger.debug(f"Searching relations for entity {entity_name}")
|
||||
relations = []
|
||||
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
results.append({**metadata, "id": metadata.get("__id__")})
|
||||
|
||||
return results
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the vector database storage file if it exists
|
||||
2. Reinitialize the vector database client
|
||||
3. Update flags to notify other processes
|
||||
4. Changes are persisted to disk immediately
|
||||
|
||||
This method will remove all vectors from the Faiss index and delete the storage files.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# Reset the index
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
|
||||
# Remove storage files if they exist
|
||||
if os.path.exists(self._faiss_index_file):
|
||||
os.remove(self._faiss_index_file)
|
||||
if os.path.exists(self._meta_file):
|
||||
os.remove(self._meta_file)
|
||||
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
|
||||
# Notify other processes
|
||||
await set_all_update_flags(self.namespace)
|
||||
self.storage_updated.value = False
|
||||
|
||||
logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
|
||||
if not pm.is_installed("gremlinpython"):
|
||||
pm.install("gremlinpython")
|
||||
|
||||
from gremlin_python.driver import client, serializer
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||
from gremlin_python.driver.protocol import GremlinServerError
|
||||
from gremlin_python.driver import client, serializer # type: ignore
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
|
||||
from gremlin_python.driver.protocol import GremlinServerError # type: ignore
|
||||
|
||||
|
||||
@final
|
||||
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge deletion: {str(e)}")
|
||||
raise
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all nodes and relationships in the graph.
|
||||
|
||||
This function deletes all nodes with the specified graph name property,
|
||||
which automatically removes all associated edges.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.drop()
|
||||
"""
|
||||
await self._query(query)
|
||||
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
|
||||
return {"status": "success", "message": "graph data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping graph {self.graph_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
await clear_all_update_flags(self.namespace)
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
||||
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
async with self._storage_lock:
|
||||
return self._data.get(id)
|
||||
|
||||
async def delete(self, doc_ids: list[str]):
|
||||
async with self._storage_lock:
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await set_all_update_flags(self.namespace)
|
||||
await self.index_done_callback()
|
||||
async def delete(self, doc_ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
await set_all_update_flags(self.namespace)
|
||||
await self.index_done_callback()
|
||||
any_deleted = False
|
||||
for doc_id in doc_ids:
|
||||
result = self._data.pop(doc_id, None)
|
||||
if result is not None:
|
||||
any_deleted = True
|
||||
|
||||
if any_deleted:
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all document status data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Clear all document status data from memory
|
||||
2. Update flags to notify other processes
|
||||
3. Trigger index_done_callback to save the empty state
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
await self.index_done_callback()
|
||||
logger.info(f"Process {os.getpid()} drop {self.namespace}")
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
||||
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
async with self._storage_lock:
|
||||
any_deleted = False
|
||||
for doc_id in ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await set_all_update_flags(self.namespace)
|
||||
await self.index_done_callback()
|
||||
result = self._data.pop(doc_id, None)
|
||||
if result is not None:
|
||||
any_deleted = True
|
||||
|
||||
if any_deleted:
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by by cache mode
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
True: if the cache was dropped successfully
|
||||
False: if the cache drop failed
|
||||
"""
|
||||
if not modes:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self.delete(modes)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all data from storage and clean up resources
|
||||
This action persists the data to disk immediately.
|
||||
|
||||
This method will:
|
||||
1. Clear all data from memory
|
||||
2. Update flags to notify other processes
|
||||
3. Trigger index_done_callback to save the empty state
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
await self.index_done_callback()
|
||||
logger.info(f"Process {os.getpid()} drop {self.namespace}")
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
|
||||
pm.install("pymilvus")
|
||||
|
||||
import configparser
|
||||
from pymilvus import MilvusClient
|
||||
from pymilvus import MilvusClient # type: ignore
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will delete all data from the Milvus collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
# Drop the collection and recreate it
|
||||
if self._client.has_collection(self.namespace):
|
||||
self._client.drop_collection(self.namespace)
|
||||
|
||||
# Recreate the collection
|
||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
||||
self._client,
|
||||
self.namespace,
|
||||
dimension=self.embedding_func.embedding_dim,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
|
||||
if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
from motor.motor_asyncio import (
|
||||
from motor.motor_asyncio import ( # type: ignore
|
||||
AsyncIOMotorClient,
|
||||
AsyncIOMotorDatabase,
|
||||
AsyncIOMotorCollection,
|
||||
)
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
from pymongo.operations import SearchIndexModel # type: ignore
|
||||
from pymongo.errors import PyMongoError # type: ignore
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
|
||||
# Mongo handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete documents with specified IDs
|
||||
|
||||
Args:
|
||||
ids: List of document IDs to be deleted
|
||||
"""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self._data.delete_many({"_id": {"$in": ids}})
|
||||
logger.info(
|
||||
f"Deleted {result.deleted_count} documents from {self.namespace}"
|
||||
)
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error deleting documents from {self.namespace}: {e}")
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
if not modes:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Build regex pattern to match documents with the specified modes
|
||||
pattern = f"^({'|'.join(modes)})_"
|
||||
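# e.g. modes=["local", "global"] yields "^(local|global)_" (illustrative cache-key prefixes)
|
||||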
result = await self._data.delete_many({"_id": {"$regex": pattern}})
|
||||
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
||||
return False
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
result = await self._data.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||
# Mongo handles persistence automatically
|
||||
pass
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
result = await self._data.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
|
||||
logger.debug(f"Successfully deleted edges: {edges}")
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
result = await self.collection.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from graph {self._collection_name}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping graph {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection and recreating vector index.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
# Delete all documents
|
||||
result = await self._data.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
# Recreate vector index
|
||||
await self.create_vector_index_if_not_exists()
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped and vector index recreated",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
||||
collection_names = await db.list_collection_names()
|
||||
|
@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
return self._client
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
try:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(
|
||||
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
storage = getattr(client, "_NanoVectorDB__storage")
|
||||
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
client = await self._get_client()
|
||||
return client.get(ids)
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the vector database storage file if it exists
|
||||
2. Reinitialize the vector database client
|
||||
3. Update flags to notify other processes
|
||||
4. Changes are persisted to disk immediately
|
||||
|
||||
This method is intended for scenarios where all data needs to be removed.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# delete _client_file_name
|
||||
if os.path.exists(self._client_file_name):
|
||||
os.remove(self._client_file_name)
|
||||
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -1,9 +1,8 @@
import asyncio
import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, final, Optional
from typing import Any, final
import numpy as np
import configparser

@@ -29,7 +28,6 @@ from neo4j import (  # type: ignore
    exceptions as neo4jExceptions,
    AsyncDriver,
    AsyncManagedTransaction,
    GraphDatabase,
)

config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
            embedding_func=embedding_func,
        )
        self._driver = None
        self._driver_lock = asyncio.Lock()

    def __post_init__(self):
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def initialize(self):
        URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
        USERNAME = os.environ.get(
            "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
            ),
        )
        DATABASE = os.environ.get(
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
        )

        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
        )

        # Try to connect to the database
        with GraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
        ) as _sync_driver:
            for database in (DATABASE, None):
                self._DATABASE = database
                connected = False
        # Try to connect to the database and create it if it doesn't exist
        for database in (DATABASE, None):
            self._DATABASE = database
            connected = False

            try:
                with _sync_driver.session(database=database) as session:
                    try:
                        session.run("MATCH (n) RETURN n LIMIT 0")
                        logger.info(f"Connected to {database} at {URI}")
                        connected = True
                    except neo4jExceptions.ServiceUnavailable as e:
                        logger.error(
                            f"{database} at {URI} is not available".capitalize()
                        )
                        raise e
            except neo4jExceptions.AuthError as e:
                logger.error(f"Authentication failed for {database} at {URI}")
                raise e
            except neo4jExceptions.ClientError as e:
                if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                    logger.info(
                        f"{database} at {URI} not found. Try to create specified database.".capitalize()
            try:
                async with self._driver.session(database=database) as session:
                    try:
                        result = await session.run("MATCH (n) RETURN n LIMIT 0")
                        await result.consume()  # Ensure result is consumed
                        logger.info(f"Connected to {database} at {URI}")
                        connected = True
                    except neo4jExceptions.ServiceUnavailable as e:
                        logger.error(
                            f"{database} at {URI} is not available".capitalize()
                        )
                    try:
                        with _sync_driver.session() as session:
                            session.run(
                                f"CREATE DATABASE `{database}` IF NOT EXISTS"
                        raise e
            except neo4jExceptions.AuthError as e:
                logger.error(f"Authentication failed for {database} at {URI}")
                raise e
            except neo4jExceptions.ClientError as e:
                if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                    logger.info(
                        f"{database} at {URI} not found. Try to create specified database.".capitalize()
                    )
                    try:
                        async with self._driver.session() as session:
                            result = await session.run(
                                f"CREATE DATABASE `{database}` IF NOT EXISTS"
                            )
                            await result.consume()  # Ensure result is consumed
                            logger.info(f"{database} at {URI} created".capitalize())
                            connected = True
                    except (
                        neo4jExceptions.ClientError,
                        neo4jExceptions.DatabaseError,
                    ) as e:
                        if (
                            e.code
                            == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                        ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
                            if database is not None:
                                logger.warning(
                                    "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                                )
                            logger.info(f"{database} at {URI} created".capitalize())
                            connected = True
                    except (
                        neo4jExceptions.ClientError,
                        neo4jExceptions.DatabaseError,
                    ) as e:
                        if (
                            e.code
                            == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                        ) or (
                            e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
                        ):
                            if database is not None:
                                logger.warning(
                                    "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                                )
                        if database is None:
                            logger.error(f"Failed to create {database} at {URI}")
                            raise e
                        if database is None:
                            logger.error(f"Failed to create {database} at {URI}")
                            raise e

            if connected:
                break
            if connected:
                # Create index for base nodes on entity_id if it doesn't exist
                try:
                    async with self._driver.session(database=database) as session:
                        # Check if index exists first
                        check_query = """
                        CALL db.indexes() YIELD name, labelsOrTypes, properties
                        WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
                        RETURN count(*) > 0 AS exists
                        """
                        try:
                            check_result = await session.run(check_query)
                            record = await check_result.single()
                            await check_result.consume()

    def __post_init__(self):
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }
                            index_exists = record and record.get("exists", False)

    async def close(self):
                            if not index_exists:
                                # Create index only if it doesn't exist
                                result = await session.run(
                                    "CREATE INDEX FOR (n:base) ON (n.entity_id)"
                                )
                                await result.consume()
                                logger.info(
                                    f"Created index for base nodes on entity_id in {database}"
                                )
                        except Exception:
                            # Fallback if db.indexes() is not supported in this Neo4j version
                            result = await session.run(
                                "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
                            )
                            await result.consume()
                except Exception as e:
                    logger.warning(f"Failed to create index: {str(e)}")
                break

    async def finalize(self):
        """Close the Neo4j driver and release all resources"""
        if self._driver:
            await self._driver.close()
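A recurring pattern in the rewritten initialization is running a query and then explicitly awaiting `consume()`, which returns the underlying connection to the pool instead of leaving it pinned to an open result. A minimal, self-contained sketch of that pattern with the async driver (the URI and credentials are placeholders):

```
from neo4j import AsyncGraphDatabase


async def count_nodes(uri: str, user: str, password: str) -> int:
    """Run one query and consume the result so the connection is released."""
    driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
    try:
        async with driver.session() as session:
            result = await session.run("MATCH (n) RETURN count(n) AS total")
            record = await result.single()
            await result.consume()  # release the connection back to the pool
            return record["total"] if record else 0
    finally:
        await driver.close()
```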
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):

    async def __aexit__(self, exc_type, exc, tb):
        """Ensure driver is closed when context manager exits"""
        await self.close()
        await self.finalize()

    async def index_done_callback(self) -> None:
        # Neo4j handles persistence automatically
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
            raise

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get node by its label identifier.
        """Get node by its label identifier, return only node properties

        Args:
            node_id: The node label to look up
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
                )
                # Return default edge properties when no edge found
                return {
                    "weight": 0.0,
                    "source_id": None,
                    "description": None,
                    "keywords": None,
                }
                # Return None when no edge found
                return None
        finally:
            await result.consume()  # Ensure result is fully consumed
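Because `get_edge` now returns `None` instead of a default property dict, the defaults move to the call site. A hedged caller-side sketch (`storage` is assumed to be an initialized `Neo4JStorage` instance):

```
async def edge_weight(storage, src: str, tgt: str) -> float:
    """Read an edge weight, supplying the default that get_edge no longer fakes."""
    edge = await storage.get_edge(src, tgt)
    if edge is None:  # no edge found; previously a default dict was returned here
        return 0.0
    return float(edge.get("weight", 0.0))
```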
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
        """
        properties = node_data
        entity_type = properties["entity_type"]
        entity_id = properties["entity_id"]
        if "entity_id" not in properties:
            raise ValueError("Neo4j: node properties must contain an 'entity_id' field")

@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
        async def execute_upsert(tx: AsyncManagedTransaction):
            query = (
                """
                MERGE (n:base {entity_id: $properties.entity_id})
                MERGE (n:base {entity_id: $entity_id})
                SET n += $properties
                SET n:`%s`
                """
                % entity_type
            )
            result = await tx.run(query, properties=properties)
            result = await tx.run(
                query, entity_id=node_id, properties=properties
            )
            logger.debug(
                f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
                f"Upserted node with entity_id '{node_id}' and properties: {properties}"
            )
            await result.consume()  # Ensure result is fully consumed

@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
        self,
        node_label: str,
        max_depth: int = 3,
        min_degree: int = 0,
        inclusive: bool = False,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
        1. min_degree does not affect nodes directly connected to the matching nodes
        2. Label matching nodes take precedence
        3. Followed by nodes directly connected to the matching nodes
        4. Finally, the degree of the nodes

        Args:
            node_label: Label of the starting node
            max_depth: Maximum depth of the subgraph
            min_degree: Minimum degree of nodes to include. Defaults to 0
            inclusive: Do an inclusive search if true
            node_label: Label of the starting node; * means all nodes
            max_depth: Maximum depth of the subgraph. Defaults to 3
            max_nodes: Maximum nodes to return by BFS. Defaults to 1000

        Returns:
            KnowledgeGraph: Complete connected subgraph for specified node
            KnowledgeGraph object containing nodes and edges, with an is_truncated flag
            indicating whether the graph was truncated due to the max_nodes limit
        """
        result = KnowledgeGraph()
        seen_nodes = set()
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
        ) as session:
            try:
                if node_label == "*":
                    # First check total node count to determine if graph is truncated
                    count_query = "MATCH (n) RETURN count(n) as total"
                    count_result = None
                    try:
                        count_result = await session.run(count_query)
                        count_record = await count_result.single()

                        if count_record and count_record["total"] > max_nodes:
                            result.is_truncated = True
                            logger.info(
                                f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                            )
                    finally:
                        if count_result:
                            await count_result.consume()

                    # Run main query to get nodes with highest degree
                    main_query = """
                    MATCH (n)
                    OPTIONAL MATCH (n)-[r]-()
                    WITH n, COALESCE(count(r), 0) AS degree
                    WHERE degree >= $min_degree
                    ORDER BY degree DESC
                    LIMIT $max_nodes
                    WITH collect({node: n}) AS filtered_nodes
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
                    RETURN filtered_nodes AS node_info,
                           collect(DISTINCT r) AS relationships
                    """
                    result_set = await session.run(
                        main_query,
                        {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                    )
                    result_set = None
                    try:
                        result_set = await session.run(
                            main_query,
                            {"max_nodes": max_nodes},
                        )
                        record = await result_set.single()
                    finally:
                        if result_set:
                            await result_set.consume()

                else:
                    # Main query uses partial matching
                    main_query = """
                    # return await self._robust_fallback(node_label, max_depth, max_nodes)
                    # First try without limit to check if we need to truncate
                    full_query = """
                    MATCH (start)
                    WHERE
                        CASE
                            WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
                            ELSE start.entity_id = $entity_id
                        END
                    WHERE start.entity_id = $entity_id
                    WITH start
                    CALL apoc.path.subgraphAll(start, {
                        relationshipFilter: '',
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
                        bfs: true
                    })
                    YIELD nodes, relationships
                    WITH start, nodes, relationships
                    WITH nodes, relationships, size(nodes) AS total_nodes
                    UNWIND nodes AS node
                    OPTIONAL MATCH (node)-[r]-()
                    WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
                    WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
                    ORDER BY
                        CASE
                            WHEN node = start THEN 3
                            WHEN EXISTS((start)--(node)) THEN 2
                            ELSE 1
                        END DESC,
                        degree DESC
                    LIMIT $max_nodes
                    WITH collect({node: node}) AS filtered_nodes
                    UNWIND filtered_nodes AS node_info
                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                    OPTIONAL MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN filtered_nodes AS node_info,
                           collect(DISTINCT r) AS relationships
                    WITH collect({node: node}) AS node_info, relationships, total_nodes
                    RETURN node_info, relationships, total_nodes
                    """
                    result_set = await session.run(
                        main_query,
                        {
                            "max_nodes": MAX_GRAPH_NODES,
                            "entity_id": node_label,
                            "inclusive": inclusive,
                            "max_depth": max_depth,
                            "min_degree": min_degree,
                        },
                    )

                try:
                    record = await result_set.single()

                    if record:
                        # Handle nodes (compatible with multi-label cases)
                        for node_info in record["node_info"]:
                            node = node_info["node"]
                            node_id = node.id
                            if node_id not in seen_nodes:
                                result.nodes.append(
                                    KnowledgeGraphNode(
                                        id=f"{node_id}",
                                        labels=[node.get("entity_id")],
                                        properties=dict(node),
                                    )
                                )
                                seen_nodes.add(node_id)

                        # Handle relationships (including direction information)
                        for rel in record["relationships"]:
                            edge_id = rel.id
                            if edge_id not in seen_edges:
                                start = rel.start_node
                                end = rel.end_node
                                result.edges.append(
                                    KnowledgeGraphEdge(
                                        id=f"{edge_id}",
                                        type=rel.type,
                                        source=f"{start.id}",
                                        target=f"{end.id}",
                                        properties=dict(rel),
                                    )
                                )
                                seen_edges.add(edge_id)

                        logger.info(
                            f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
                    # Try to get full result
                    full_result = None
                    try:
                        full_result = await session.run(
                            full_query,
                            {
                                "entity_id": node_label,
                                "max_depth": max_depth,
                            },
                        )
                finally:
                    await result_set.consume()  # Ensure result set is consumed
                        full_record = await full_result.single()

                        # If no record found, return empty KnowledgeGraph
                        if not full_record:
                            logger.debug(f"No nodes found for entity_id: {node_label}")
                            return result

                        # If record found, check node count
                        total_nodes = full_record["total_nodes"]

                        if total_nodes <= max_nodes:
                            # If node count is within limit, use full result directly
                            logger.debug(
                                f"Using full result with {total_nodes} nodes (no truncation needed)"
                            )
                            record = full_record
                        else:
                            # If node count exceeds limit, set truncated flag and run limited query
                            result.is_truncated = True
                            logger.info(
                                f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
                            )

                            # Run limited query
                            limited_query = """
                            MATCH (start)
                            WHERE start.entity_id = $entity_id
                            WITH start
                            CALL apoc.path.subgraphAll(start, {
                                relationshipFilter: '',
                                minLevel: 0,
                                maxLevel: $max_depth,
                                limit: $max_nodes,
                                bfs: true
                            })
                            YIELD nodes, relationships
                            UNWIND nodes AS node
                            WITH collect({node: node}) AS node_info, relationships
                            RETURN node_info, relationships
                            """
                            result_set = None
                            try:
                                result_set = await session.run(
                                    limited_query,
                                    {
                                        "entity_id": node_label,
                                        "max_depth": max_depth,
                                        "max_nodes": max_nodes,
                                    },
                                )
                                record = await result_set.single()
                            finally:
                                if result_set:
                                    await result_set.consume()
                    finally:
                        if full_result:
                            await full_result.consume()

                if record:
                    # Handle nodes (compatible with multi-label cases)
                    for node_info in record["node_info"]:
                        node = node_info["node"]
                        node_id = node.id
                        if node_id not in seen_nodes:
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=f"{node_id}",
                                    labels=[node.get("entity_id")],
                                    properties=dict(node),
                                )
                            )
                            seen_nodes.add(node_id)

                    # Handle relationships (including direction information)
                    for rel in record["relationships"]:
                        edge_id = rel.id
                        if edge_id not in seen_edges:
                            start = rel.start_node
                            end = rel.end_node
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=f"{edge_id}",
                                    type=rel.type,
                                    source=f"{start.id}",
                                    target=f"{end.id}",
                                    properties=dict(rel),
                                )
                            )
                            seen_edges.add(edge_id)

                logger.info(
                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                )

            except neo4jExceptions.ClientError as e:
                logger.warning(f"APOC plugin error: {str(e)}")
@@ -767,110 +837,28 @@ class Neo4JStorage(BaseGraphStorage):
                    logger.warning(
                        "Neo4j: falling back to basic Cypher recursive search..."
                    )
                    if inclusive:
                        logger.warning(
                            "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
                        )
                    return await self._robust_fallback(
                        node_label, max_depth, min_degree
                    return await self._robust_fallback(node_label, max_depth, max_nodes)
                else:
                    logger.warning(
                        "Neo4j: APOC plugin error with wildcard query, returning empty result"
                    )

        return result

    async def _robust_fallback(
        self, node_label: str, max_depth: int, min_degree: int = 0
        self, node_label: str, max_depth: int, max_nodes: int
    ) -> KnowledgeGraph:
        """
        Fallback implementation when APOC plugin is not available or incompatible.
        This method implements the same functionality as get_knowledge_graph but uses
        only basic Cypher queries and recursive traversal instead of APOC procedures.
        only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
        """
        from collections import deque

        result = KnowledgeGraph()
        visited_nodes = set()
        visited_edges = set()

        async def traverse(
            node: KnowledgeGraphNode,
            edge: Optional[KnowledgeGraphEdge],
            current_depth: int,
        ):
            # Check traversal limits
            if current_depth > max_depth:
                logger.debug(f"Reached max depth: {max_depth}")
                return
            if len(visited_nodes) >= MAX_GRAPH_NODES:
                logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
                return

            # Check if node already visited
            if node.id in visited_nodes:
                return

            # Get all edges and target nodes
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                query = """
                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
                WITH r, b, id(r) as edge_id, id(b) as target_id
                RETURN r, b, edge_id, target_id
                """
                results = await session.run(query, entity_id=node.id)

                # Get all records and release database connection
                records = await results.fetch(
                    1000
                )  # Max neighbour nodes we can handle
                await results.consume()  # Ensure results are consumed

            # Nodes not connected to start node need to check degree
            if current_depth > 1 and len(records) < min_degree:
                return

            # Add current node to result
            result.nodes.append(node)
            visited_nodes.add(node.id)

            # Add edge to result if it exists and not already added
            if edge and edge.id not in visited_edges:
                result.edges.append(edge)
                visited_edges.add(edge.id)

            # Prepare nodes and edges for recursive processing
            nodes_to_process = []
            for record in records:
                rel = record["r"]
                edge_id = str(record["edge_id"])
                if edge_id not in visited_edges:
                    b_node = record["b"]
                    target_id = b_node.get("entity_id")

                    if target_id:  # Only process if target node has entity_id
                        # Create KnowledgeGraphNode for target
                        target_node = KnowledgeGraphNode(
                            id=f"{target_id}",
                            labels=list(f"{target_id}"),
                            properties=dict(b_node.properties),
                        )

                        # Create KnowledgeGraphEdge
                        target_edge = KnowledgeGraphEdge(
                            id=f"{edge_id}",
                            type=rel.type,
                            source=f"{node.id}",
                            target=f"{target_id}",
                            properties=dict(rel),
                        )

                        nodes_to_process.append((target_node, target_edge))
                    else:
                        logger.warning(
                            f"Skipping edge {edge_id} due to missing labels on target node"
                        )

            # Process nodes after releasing database connection
            for target_node, target_edge in nodes_to_process:
                await traverse(target_node, target_edge, current_depth + 1)
        visited_edge_pairs = set()  # Track processed edge pairs as (sorted source_id, target_id)

        # Get the starting node's data
        async with self._driver.session(
@@ -889,15 +877,129 @@ class Neo4JStorage(BaseGraphStorage):
                # Create initial KnowledgeGraphNode
                start_node = KnowledgeGraphNode(
                    id=f"{node_record['n'].get('entity_id')}",
                    labels=list(f"{node_record['n'].get('entity_id')}"),
                    properties=dict(node_record["n"].properties),
                    labels=[node_record["n"].get("entity_id")],
                    properties=dict(node_record["n"]._properties),
                )
            finally:
                await node_result.consume()  # Ensure results are consumed

        # Start traversal with the initial node
        await traverse(start_node, None, 0)
        # Initialize queue for BFS with (node, edge, depth) tuples
        # edge is None for the starting node
        queue = deque([(start_node, None, 0)])

        # True BFS implementation using a queue
        while queue and len(visited_nodes) < max_nodes:
            # Dequeue the next node to process
            current_node, current_edge, current_depth = queue.popleft()

            # Skip if already visited or exceeds max depth
            if current_node.id in visited_nodes:
                continue

            if current_depth > max_depth:
                logger.debug(
                    f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
                )
                continue

            # Add current node to result
            result.nodes.append(current_node)
            visited_nodes.add(current_node.id)

            # Add edge to result if it exists and not already added
            if current_edge and current_edge.id not in visited_edges:
                result.edges.append(current_edge)
                visited_edges.add(current_edge.id)

            # Stop if we've reached the node limit
            if len(visited_nodes) >= max_nodes:
                result.is_truncated = True
                logger.info(
                    f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
                )
                break

            # Get all edges and target nodes for the current node (even at max_depth)
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                query = """
                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
                WITH r, b, id(r) as edge_id, id(b) as target_id
                RETURN r, b, edge_id, target_id
                """
                results = await session.run(query, entity_id=current_node.id)

                # Get all records and release database connection
                records = await results.fetch(1000)  # Max neighbor nodes we can handle
                await results.consume()  # Ensure results are consumed

            # Process all neighbors - capture all edges but only queue unvisited nodes
            for record in records:
                rel = record["r"]
                edge_id = str(record["edge_id"])

                if edge_id not in visited_edges:
                    b_node = record["b"]
                    target_id = b_node.get("entity_id")

                    if target_id:  # Only process if target node has entity_id
                        # Create KnowledgeGraphNode for target
                        target_node = KnowledgeGraphNode(
                            id=f"{target_id}",
                            labels=[target_id],
                            properties=dict(b_node._properties),
                        )

                        # Create KnowledgeGraphEdge
                        target_edge = KnowledgeGraphEdge(
                            id=f"{edge_id}",
                            type=rel.type,
                            source=f"{current_node.id}",
                            target=f"{target_id}",
                            properties=dict(rel),
                        )

                        # Sort source_id and target_id so (A, B) and (B, A) are treated as the same edge
                        sorted_pair = tuple(sorted([current_node.id, target_id]))

                        # Check whether the same edge already exists (treating edges as undirected)
                        if sorted_pair not in visited_edge_pairs:
                            # Only add the edge if the target node is already in the result or will be added to it
                            if target_id in visited_nodes or (
                                target_id not in visited_nodes
                                and current_depth < max_depth
                            ):
                                result.edges.append(target_edge)
                                visited_edges.add(edge_id)
                                visited_edge_pairs.add(sorted_pair)

                        # Only add unvisited nodes to the queue for further expansion
                        if target_id not in visited_nodes:
                            # Only add to queue if we're not at max depth yet
                            if current_depth < max_depth:
                                # Add node to queue with incremented depth
                                # Edge is already added to result, so we pass None as edge
                                queue.append((target_node, None, current_depth + 1))
                            else:
                                # At max depth, we've already added the edge but we don't add the node
                                # This prevents adding nodes beyond max_depth to the result
                                logger.debug(
                                    f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
                                )
                        else:
                            # If target node already exists in result, we don't need to add it again
                            logger.debug(
                                f"Node {target_id} already visited, edge added but node not queued"
                            )
                    else:
                        logger.warning(
                            f"Skipping edge {edge_id} due to missing entity_id on target node"
                        )

        logger.info(
            f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
        )
        return result

    async def get_all_labels(self) -> list[str]:
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):

        # Method 2: Query compatible with older versions
        query = """
        MATCH (n)
        MATCH (n:base)
        WHERE n.entity_id IS NOT NULL
        RETURN DISTINCT n.entity_id AS label
        ORDER BY label
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
        self, algorithm: str
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
        raise NotImplementedError

    async def drop(self) -> dict[str, str]:
        """Drop all data from storage and clean up resources

        This method will delete all nodes and relationships in the Neo4j database.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._driver.session(database=self._DATABASE) as session:
                # Delete all nodes and relationships
                query = "MATCH (n) DETACH DELETE n"
                result = await session.run(query)
                await result.consume()  # Ensure result is fully consumed

                logger.info(
                    f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
            return {"status": "error", "message": str(e)}
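The `_robust_fallback` rewrite above trades recursion for an explicit queue. Stripped of the Neo4j session handling, the traversal is plain bounded breadth-first search; the following self-contained sketch over an in-memory adjacency dict mirrors its depth and node-budget checks (it is an illustration, not the storage code itself):

```
from collections import deque


def bfs_subgraph(adj: dict, start: str, max_depth: int, max_nodes: int):
    """Breadth-first expansion with a node budget and a truncation flag."""
    visited, order, truncated = set(), [], False
    queue = deque([(start, 0)])
    while queue:
        node, depth = queue.popleft()
        if node in visited or depth > max_depth:
            continue
        if len(visited) >= max_nodes:
            truncated = True  # budget exhausted; return the partial result
            break
        visited.add(node)
        order.append(node)
        if depth < max_depth:
            queue.extend((n, depth + 1) for n in adj.get(node, []) if n not in visited)
    return order, truncated


# A tiny graph, depth 2, budget 3 -> (['a', 'b', 'c'], True)
print(bfs_subgraph({"a": ["b", "c"], "b": ["d"]}, "a", 2, 3))
```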
@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
        )
        nx.write_graphml(graph, file_name)

    # TODO: deprecated, remove later
    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
        return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Important notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should be updating the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        graph = await self._get_graph()
        graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Important notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should be updating the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        graph = await self._get_graph()
        graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def delete_node(self, node_id: str) -> None:
        """
        Important notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should be updating the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption
        """
        graph = await self._get_graph()
        if graph.has_node(node_id):
            graph.remove_node(node_id)
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
        else:
            logger.warning(f"Node {node_id} not found in the graph for deletion.")

    # TODO: NOT USED
    async def embed_nodes(
        self, algorithm: str
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Important notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should be updating the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption

        Args:
            nodes: List of node IDs to be deleted
        """
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Important notes:
        1. Changes will be persisted to disk during the next index_done_callback
        2. Only one process should be updating the storage at a time before index_done_callback,
           KG-storage-log should be used to avoid data corruption

        Args:
            edges: List of edges to be deleted, each edge is a (source, target) tuple
        """
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
        self,
        node_label: str,
        max_depth: int = 3,
        min_degree: int = 0,
        inclusive: bool = False,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
        When reducing the number of nodes, the prioritization criteria are as follows:
        1. min_degree does not affect nodes directly connected to the matching nodes
        2. Label matching nodes take precedence
        3. Followed by nodes directly connected to the matching nodes
        4. Finally, the degree of the nodes

        Args:
            node_label: Label of the starting node
            max_depth: Maximum depth of the subgraph
            min_degree: Minimum degree of nodes to include. Defaults to 0
            inclusive: Do an inclusive search if true
            node_label: Label of the starting node; * means all nodes
            max_depth: Maximum depth of the subgraph. Defaults to 3
            max_nodes: Maximum nodes to return by BFS. Defaults to 1000

        Returns:
            KnowledgeGraph object containing nodes and edges
            KnowledgeGraph object containing nodes and edges, with an is_truncated flag
            indicating whether the graph was truncated due to the max_nodes limit
        """
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()

        graph = await self._get_graph()

        # Initialize sets for start nodes and direct connected nodes
        start_nodes = set()
        direct_connected_nodes = set()
        result = KnowledgeGraph()

        # Handle special case for "*" label
        if node_label == "*":
            # For "*", return the entire graph including all nodes and edges
            subgraph = (
                graph.copy()
            )  # Create a copy to avoid modifying the original graph
            # Get degrees of all nodes
            degrees = dict(graph.degree())
            # Sort nodes by degree in descending order and take top max_nodes
            sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)

            # Check if graph is truncated
            if len(sorted_nodes) > max_nodes:
                result.is_truncated = True
                logger.info(
                    f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
                )

            limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
            # Create subgraph with the highest degree nodes
            subgraph = graph.subgraph(limited_nodes)
        else:
            # Find nodes with matching node id based on search_mode
            nodes_to_explore = []
            for n, attr in graph.nodes(data=True):
                node_str = str(n)
                if not inclusive:
                    if node_label == node_str:  # Use exact matching
                        nodes_to_explore.append(n)
                else:  # inclusive mode
                    if node_label in node_str:  # Use partial matching
                        nodes_to_explore.append(n)
            # Check if node exists
            if node_label not in graph:
                logger.warning(f"Node {node_label} not found in the graph")
                return KnowledgeGraph()  # Return empty graph

            if not nodes_to_explore:
                logger.warning(f"No nodes found with label {node_label}")
                return result
            # Use BFS to get nodes
            bfs_nodes = []
            visited = set()
            queue = [(node_label, 0)]  # (node, depth) tuple

            # Get subgraph using ego_graph from all matching nodes
            combined_subgraph = nx.Graph()
            for start_node in nodes_to_explore:
                node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
                combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
            # Breadth-first search
            while queue and len(bfs_nodes) < max_nodes:
                current, depth = queue.pop(0)
                if current not in visited:
                    visited.add(current)
                    bfs_nodes.append(current)

            # Get start nodes and direct connected nodes
            if nodes_to_explore:
                start_nodes = set(nodes_to_explore)
                # Get nodes directly connected to all start nodes
                for start_node in start_nodes:
                    direct_connected_nodes.update(
                        combined_subgraph.neighbors(start_node)
                    )
                    # Only explore neighbors if we haven't reached max_depth
                    if depth < max_depth:
                        # Add neighbor nodes to queue with incremented depth
                        neighbors = list(graph.neighbors(current))
                        queue.extend(
                            [(n, depth + 1) for n in neighbors if n not in visited]
                        )

                # Remove start nodes from directly connected nodes (avoid duplicates)
                direct_connected_nodes -= start_nodes
            # Check if graph is truncated - if we still have nodes in the queue
            # and we've reached max_nodes, then the graph is truncated
            if queue and len(bfs_nodes) >= max_nodes:
                result.is_truncated = True
                logger.info(
                    f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
                )

            subgraph = combined_subgraph

            # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
            if min_degree > 0:
                nodes_to_keep = [
                    node
                    for node, degree in subgraph.degree()
                    if node in start_nodes
                    or node in direct_connected_nodes
                    or degree >= min_degree
                ]
                subgraph = subgraph.subgraph(nodes_to_keep)

        # Check if number of nodes exceeds max_graph_nodes
        if len(subgraph.nodes()) > MAX_GRAPH_NODES:
            origin_nodes = len(subgraph.nodes())
            node_degrees = dict(subgraph.degree())

            def priority_key(node_item):
                node, degree = node_item
                # Priority order: start(2) > directly connected(1) > other nodes(0)
                if node in start_nodes:
                    priority = 2
                elif node in direct_connected_nodes:
                    priority = 1
                else:
                    priority = 0
                return (priority, degree)

            # Sort by priority and degree and select top MAX_GRAPH_NODES nodes
            top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
                :MAX_GRAPH_NODES
            ]
            top_node_ids = [node[0] for node in top_nodes]
            # Create new subgraph and keep nodes only with most degree
            subgraph = subgraph.subgraph(top_node_ids)
            logger.info(
                f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
            )
            # Create subgraph with BFS discovered nodes
            subgraph = graph.subgraph(bfs_nodes)

        # Add nodes to result
        seen_nodes = set()
        seen_edges = set()
        for node in subgraph.nodes():
            if str(node) in seen_nodes:
                continue
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
        for edge in subgraph.edges():
            source, target = edge
            # Ensure a unique edge_id for the undirected graph
            if source > target:
            if str(source) > str(target):
                source, target = target, source
            edge_id = f"{source}-{target}"
            if edge_id in seen_edges:
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
            return False  # Return error

        return True

    async def drop(self) -> dict[str, str]:
        """Drop all graph data from storage and clean up resources

        This method will:
        1. Remove the graph storage file if it exists
        2. Reset the graph to an empty state
        3. Update flags to notify other processes
        4. Changes are persisted to disk immediately

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._storage_lock:
                # Delete the graphml storage file
                if os.path.exists(self._graphml_xml_file):
                    os.remove(self._graphml_xml_file)
                self._graph = nx.Graph()
                # Notify other processes that data has been updated
                await set_all_update_flags(self.namespace)
                # Reset own update flag to avoid self-reloading
                self.storage_updated.value = False
                logger.info(
                    f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping graph {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
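For the `*` case the NetworkX backend now keeps the highest-degree nodes rather than filtering by `min_degree`. The selection reduces to one sort plus `graph.subgraph`, as this standalone illustration with a toy graph shows:

```
import networkx as nx


def top_degree_subgraph(graph: nx.Graph, max_nodes: int):
    """Keep the max_nodes highest-degree nodes and flag truncation."""
    ranked = sorted(graph.degree(), key=lambda item: item[1], reverse=True)
    kept = [node for node, _ in ranked[:max_nodes]]
    return graph.subgraph(kept), len(ranked) > max_nodes


g = nx.path_graph(5)  # 0-1-2-3-4; interior nodes have degree 2
sub, truncated = top_degree_subgraph(g, 3)
print(sorted(sub.nodes()), truncated)  # [1, 2, 3] True
```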
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -8,17 +8,15 @@ import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import configparser


config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

import pipmaster as pm

if not pm.is_installed("qdrant-client"):
    pm.install("qdrant-client")

from qdrant_client import QdrantClient, models
from qdrant_client import QdrantClient, models  # type: ignore

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


def compute_mdhash_id_for_qdrant(
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
        except Exception as e:
            logger.error(f"Error searching for prefix '{prefix}': {e}")
            return []

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID

        Args:
            id: The unique identifier of the vector

        Returns:
            The vector data if found, or None if not found
        """
        try:
            # Convert to Qdrant compatible ID
            qdrant_id = compute_mdhash_id_for_qdrant(id)

            # Retrieve the point by ID
            result = self._client.retrieve(
                collection_name=self.namespace,
                ids=[qdrant_id],
                with_payload=True,
            )

            if not result:
                return None

            return result[0].payload
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data by their IDs

        Args:
            ids: List of unique identifiers

        Returns:
            List of vector data objects that were found
        """
        if not ids:
            return []

        try:
            # Convert to Qdrant compatible IDs
            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]

            # Retrieve the points by IDs
            results = self._client.retrieve(
                collection_name=self.namespace,
                ids=qdrant_ids,
                with_payload=True,
            )

            return [point.payload for point in results]
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []

    async def drop(self) -> dict[str, str]:
        """Drop all vector data from storage and clean up resources

        This method will delete all data from the Qdrant collection.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            # Delete the collection and recreate it
            if self._client.collection_exists(self.namespace):
                self._client.delete_collection(self.namespace)

            # Recreate the collection
            QdrantVectorDBStorage.create_collection_if_not_exist(
                self._client,
                self.namespace,
                vectors_config=models.VectorParams(
                    size=self.embedding_func.embedding_dim,
                    distance=models.Distance.COSINE,
                ),
            )

            logger.info(
                f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
            )
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}
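Qdrant accepts only unsigned integers or UUIDs as point IDs, which is why `get_by_id` and `get_by_ids` route every LightRAG ID through `compute_mdhash_id_for_qdrant` before calling `retrieve`. The real helper is defined earlier in this file; the sketch below shows the general style of such a deterministic mapping and its exact hashing is an assumption, not the module's implementation:

```
import uuid


def to_qdrant_id(raw_id: str) -> str:
    """Deterministically map an arbitrary string ID to a UUID string.

    uuid5 is stable across runs, so the same LightRAG ID always resolves
    to the same Qdrant point; NAMESPACE_DNS is just a fixed namespace.
    """
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, raw_id))


print(to_qdrant_id("ent-123") == to_qdrant_id("ent-123"))  # True: stable mapping
```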
@@ -12,6 +12,7 @@ if not pm.is_installed("redis"):
from redis.asyncio import Redis, ConnectionPool
from redis.exceptions import RedisError, ConnectionError
from lightrag.utils import logger, compute_mdhash_id

from lightrag.base import BaseKVStorage
import json

@@ -121,7 +122,11 @@ class RedisKVStorage(BaseKVStorage):
        except json.JSONEncodeError as e:
            logger.error(f"JSON encode error during upsert: {e}")
            raise


    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete entries with specified IDs"""
        if not ids:
@@ -138,71 +143,52 @@ class RedisKVStorage(BaseKVStorage):
            f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
        )

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by name"""
    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Important notes for Redis storage:
        1. This will immediately delete the specified cache modes from Redis

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            True: if the cache was dropped successfully
            False: if the cache drop failed
        """
        if not modes:
            return False

        try:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )
            await self.delete(modes)
            return True
        except Exception:
            return False

            async with self._get_redis_connection() as redis:
                result = await redis.delete(f"{self.namespace}:{entity_id}")
    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all keys under the current namespace.

                if result:
                    logger.debug(f"Successfully deleted entity {entity_name}")
                else:
                    logger.debug(f"Entity {entity_name} not found in storage")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")
        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        async with self._get_redis_connection() as redis:
            try:
                keys = await redis.keys(f"{self.namespace}:*")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity"""
        try:
            async with self._get_redis_connection() as redis:
                cursor = 0
                relation_keys = []
                pattern = f"{self.namespace}:*"

                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern)

                    # Process keys in batches
                    if keys:
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        values = await pipe.execute()

                        for key, value in zip(keys, values):
                            if value:
                                try:
                                    data = json.loads(value)
                                    if (
                                        data.get("src_id") == entity_name
                                        or data.get("tgt_id") == entity_name
                                    ):
                                        relation_keys.append(key)
                                except json.JSONDecodeError:
                                    logger.warning(f"Invalid JSON in key {key}")
                                    continue
                        pipe.delete(key)
                    results = await pipe.execute()
                    deleted_count = sum(results)

                    if cursor == 0:
                        break

                # Delete relations in batches
                if relation_keys:
                    # Delete in chunks to avoid too many arguments
                    chunk_size = 1000
                    for i in range(0, len(relation_keys), chunk_size):
                        chunk = relation_keys[i:i + chunk_size]
                        deleted = await redis.delete(*chunk)
                        logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
                    logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
                    return {"status": "success", "message": f"{deleted_count} keys dropped"}
                else:
                    logger.debug(f"No relations found for entity {entity_name}")
                    logger.info(f"No keys found to drop in {self.namespace}")
                    return {"status": "success", "message": "no keys to drop"}

        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")
            except Exception as e:
                logger.error(f"Error dropping keys from {self.namespace}: {e}")
                return {"status": "error", "message": str(e)}

    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass
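The new `drop` uses `KEYS`, which is simple but scans the whole keyspace in one blocking call; on large instances an incremental `SCAN` loop (as the removed `delete_entity_relation` used) is gentler on the server. A hedged sketch of that variant, with connection setup assumed:

```
from redis.asyncio import Redis


async def drop_namespace(redis: Redis, namespace: str) -> int:
    """Delete all keys under `namespace:*` incrementally via SCAN."""
    deleted, cursor = 0, 0
    while True:
        cursor, keys = await redis.scan(cursor, match=f"{namespace}:*", count=500)
        if keys:
            deleted += await redis.delete(*keys)
        if cursor == 0:  # SCAN reports completion with cursor 0
            break
    return deleted
```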
@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy import create_engine, text # type: ignore
|
||||
|
||||
|
||||
class TiDB:
|
||||
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
# Ti handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete records with specified IDs from the storage.
|
||||
|
||||
Args:
|
||||
ids: List of record IDs to be deleted
|
||||
"""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
id_field = namespace_to_id(self.namespace)
|
||||
|
||||
if not table_name or not id_field:
|
||||
logger.error(f"Unknown namespace for deletion: {self.namespace}")
|
||||
return
|
||||
|
||||
ids_list = ",".join([f"'{id}'" for id in ids])
|
||||
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
|
||||
|
||||
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
|
||||
logger.info(
|
||||
f"Successfully deleted {len(ids)} records from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting records from {self.namespace}: {e}")
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
if not modes:
|
||||
return False
|
||||
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
return False
|
||||
|
||||
if table_name != "LIGHTRAG_LLM_CACHE":
|
||||
return False
|
||||
|
||||
# 构建MySQL风格的IN查询
|
||||
modes_list = ", ".join([f"'{mode}'" for mode in modes])
|
||||
sql = f"""
|
||||
DELETE FROM {table_name}
|
||||
WHERE workspace = :workspace
|
||||
AND mode IN ({modes_list})
|
||||
"""
|
||||
|
||||
logger.info(f"Deleting cache by modes: {modes}")
|
||||
await self.db.execute(sql, {"workspace": self.db.workspace})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
||||
return False
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage"""
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unknown namespace: {self.namespace}",
|
||||
}
|
||||
|
||||
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
||||
table_name=table_name
|
||||
)
|
||||
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete vectors with specified IDs from the storage.
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
id_field = namespace_to_id(self.namespace)
|
||||
|
||||
if not table_name or not id_field:
|
||||
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
||||
return
|
||||
|
||||
ids_list = ",".join([f"'{id}'" for id in ids])
|
||||
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
|
||||
|
||||
try:
|
||||
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
|
||||
logger.debug(
|
||||
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
raise NotImplementedError
|
||||
"""Delete an entity by its name from the vector storage.
|
||||
|
||||
Args:
|
||||
entity_name: The name of the entity to delete
|
||||
"""
|
||||
try:
|
||||
# Construct SQL to delete the entity
|
||||
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE workspace = :workspace AND name = :entity_name"""
|
||||
|
||||
await self.db.execute(
|
||||
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
|
||||
)
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
raise NotImplementedError
|
||||
"""Delete all relations associated with an entity.
|
||||
|
||||
Args:
|
||||
entity_name: The name of the entity whose relations should be deleted
|
||||
"""
|
||||
try:
|
||||
# Delete relations where the entity is either the source or target
|
||||
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
|
||||
|
||||
await self.db.execute(
|
||||
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
|
||||
)
|
||||
logger.debug(f"Successfully deleted relations for entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Ti handles persistence automatically
|
||||
pass
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage"""
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unknown namespace: {self.namespace}",
|
||||
}
|
||||
|
||||
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
||||
table_name=table_name
|
||||
)
|
||||
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
        # TiDB handles persistence automatically
        pass

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            drop_sql = """
                DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
                DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
            """
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    async def delete_node(self, node_id: str) -> None:
        """Delete a node and all its related edges

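Both `drop()` implementations report their outcome through a `{"status", "message"}` dict rather than raising, so callers branch on the `"status"` key. A hedged usage sketch, where `storage` stands in for any of these storage instances:

```python
# Sketch only: consuming the {"status", "message"} contract of drop().
async def reset_storage(storage) -> bool:
    result = await storage.drop()
    if result["status"] == "success":
        print(f"dropped: {result['message']}")
        return True
    print(f"drop failed: {result['message']}")
    return False
```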
@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
        FROM LIGHTRAG_DOC_CHUNKS
        WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
    """,
    # Drop tables
    "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}

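Note the two-stage substitution in this template: identifiers such as table names cannot be bound as SQL parameters, so the table name is interpolated with `str.format` while the workspace value stays a bind parameter. A minimal sketch of how `drop()` above composes the two, with `db` standing in for the storage's database client:

```python
# Sketch only: str.format for the identifier, bind parameter for the value.
SQL_TEMPLATES = {
    "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}


async def drop_table_rows(db, table_name: str, workspace: str) -> None:
    # table_name must come from a trusted whitelist (e.g. namespace_to_table_name),
    # never from user input, since it is spliced into the SQL text.
    drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
        table_name=table_name
    )
    await db.execute(drop_sql, {"workspace": workspace})
```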
@@ -13,7 +13,6 @@ import pandas as pd


from lightrag.kg import (
    STORAGE_ENV_REQUIREMENTS,
    STORAGES,
    verify_storage_implementation,
)
@@ -230,6 +229,7 @@ class LightRAG:
    vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional parameters for vector database storage."""

    # TODO: deprecated, remove in the future; use WORKSPACE instead
    namespace_prefix: str = field(default="")
    """Prefix for namespacing stored data across different environments."""

@@ -510,36 +510,22 @@ class LightRAG:
        self,
        node_label: str,
        max_depth: int = 3,
        min_degree: int = 0,
        inclusive: bool = False,
        max_nodes: int = 1000,
    ) -> KnowledgeGraph:
        """Get knowledge graph for a given label

        Args:
            node_label (str): Label to get knowledge graph for
            max_depth (int): Maximum depth of graph
            min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
            inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
            max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.

        Returns:
            KnowledgeGraph: Knowledge graph containing nodes and edges
        """
        # get params supported by get_knowledge_graph of specified storage
        import inspect

        storage_params = inspect.signature(
            self.chunk_entity_relation_graph.get_knowledge_graph
        ).parameters

        kwargs = {"node_label": node_label, "max_depth": max_depth}

        if "min_degree" in storage_params and min_degree > 0:
            kwargs["min_degree"] = min_degree

        if "inclusive" in storage_params:
            kwargs["inclusive"] = inclusive

        return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
        return await self.chunk_entity_relation_graph.get_knowledge_graph(
            node_label, max_depth, max_nodes
        )

    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
        import_path = STORAGES[storage_name]
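The replaced implementation used `inspect.signature` to feature-detect which optional arguments a pluggable storage backend supports before forwarding them; the new code passes a fixed positional set instead. A self-contained sketch of the introspection pattern, with `ToyBackend` as a stand-in for a real storage class:

```python
# Sketch only: forward optional kwargs only if the backend declares them.
import inspect


class ToyBackend:
    def get_knowledge_graph(self, node_label: str, max_depth: int, inclusive: bool = False):
        return f"{node_label=} {max_depth=} {inclusive=}"


def call_with_supported_kwargs(backend, node_label: str, max_depth: int, **optional):
    params = inspect.signature(backend.get_knowledge_graph).parameters
    kwargs = {"node_label": node_label, "max_depth": max_depth}
    # Keep only the optional arguments this backend actually accepts.
    kwargs.update({k: v for k, v in optional.items() if k in params})
    return backend.get_knowledge_graph(**kwargs)


print(call_with_supported_kwargs(ToyBackend(), "Entity", 3, inclusive=True, min_degree=2))
# min_degree is silently dropped because ToyBackend does not declare it.
```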
@@ -1449,6 +1435,7 @@ class LightRAG:
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
    async def adelete_by_entity(self, entity_name: str) -> None:
        try:
            await self.entities_vdb.delete_entity(entity_name)
@@ -1486,6 +1473,7 @@ class LightRAG:
            self.adelete_by_relation(source_entity, target_entity)
        )

    # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
    async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
        """Asynchronously delete a relation between two entities.

@@ -1494,6 +1482,7 @@ class LightRAG:
            target_entity: Name of the target entity
        """
        try:
            # TODO: check if has_edge function works on reverse relation
            # Check if the relation exists
            edge_exists = await self.chunk_entity_relation_graph.has_edge(
                source_entity, target_entity
@@ -1554,6 +1543,7 @@ class LightRAG:
        """
        return await self.doc_status.get_docs_by_status(status)

    # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
    async def adelete_by_doc_id(self, doc_id: str) -> None:
        """Delete a document and all its related data

@@ -1586,6 +1576,8 @@ class LightRAG:
        chunk_ids = set(related_chunks.keys())
        logger.debug(f"Found {len(chunk_ids)} chunks to delete")

        # TODO: self.entities_vdb.client_storage only works for local storage, need to fix this

        # 3. Before deleting, check the related entities and relationships for these chunks
        for chunk_id in chunk_ids:
            # Check entities
@@ -1857,24 +1849,6 @@ class LightRAG:

        return result

    def check_storage_env_vars(self, storage_name: str) -> None:
        """Check if all required environment variables for storage implementation exist

        Args:
            storage_name: Storage implementation name

        Raises:
            ValueError: If required environment variables are missing
        """
        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
        missing_vars = [var for var in required_vars if var not in os.environ]

        if missing_vars:
            raise ValueError(
                f"Storage implementation '{storage_name}' requires the following "
                f"environment variables: {', '.join(missing_vars)}"
            )

    async def aclear_cache(self, modes: list[str] | None = None) -> None:
        """Clear cache data from the LLM response cache storage.

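The removed validator is a plain dict lookup against `os.environ`. A self-contained sketch of the same check; the requirements mapping below is illustrative only, the real table lives in `lightrag.kg` as `STORAGE_ENV_REQUIREMENTS`:

```python
# Sketch only: per-backend required environment variables. The mapping
# entries here are assumptions for illustration.
import os

STORAGE_ENV_REQUIREMENTS = {
    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
}


def check_storage_env_vars(storage_name: str) -> None:
    required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(
            f"Storage implementation '{storage_name}' requires the following "
            f"environment variables: {', '.join(missing_vars)}"
        )


check_storage_env_vars("JsonKVStorage")  # no requirements, so this passes
```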
@@ -1906,12 +1880,18 @@ class LightRAG:
        try:
            # Reset the cache storage for specified mode
            if modes:
                await self.llm_response_cache.delete(modes)
                logger.info(f"Cleared cache for modes: {modes}")
                success = await self.llm_response_cache.drop_cache_by_modes(modes)
                if success:
                    logger.info(f"Cleared cache for modes: {modes}")
                else:
                    logger.warning(f"Failed to clear cache for modes: {modes}")
            else:
                # Clear all modes
                await self.llm_response_cache.delete(valid_modes)
                logger.info("Cleared all cache")
                success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
                if success:
                    logger.info("Cleared all cache")
                else:
                    logger.warning("Failed to clear all cache")

            await self.llm_response_cache.index_done_callback()

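Callers can clear a subset of query modes or everything at once. A hedged usage sketch, assuming `rag` is an already-initialized `LightRAG` instance:

```python
# Sketch only: clearing the LLM response cache for selected query modes.
async def clear_caches(rag) -> None:
    await rag.aclear_cache(modes=["local", "global"])  # selected modes only
    await rag.aclear_cache()  # passing nothing clears every mode
```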
@@ -1922,6 +1902,7 @@ class LightRAG:
        """Synchronous version of aclear_cache."""
        return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))

    # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
    async def aedit_entity(
        self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
    ) -> dict[str, Any]:
@@ -2134,6 +2115,7 @@ class LightRAG:
            ]
        )

    # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
    async def aedit_relation(
        self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
    ) -> dict[str, Any]:
@@ -2448,6 +2430,7 @@ class LightRAG:
            self.acreate_relation(source_entity, target_entity, relation_data)
        )

    # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
    async def amerge_entities(
        self,
        source_entities: list[str],
@@ -44,6 +44,47 @@ class InvalidResponseError(Exception):
    pass


def create_openai_async_client(
    api_key: str | None = None,
    base_url: str | None = None,
    client_configs: dict[str, Any] = None,
) -> AsyncOpenAI:
    """Create an AsyncOpenAI client with the given configuration.

    Args:
        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        An AsyncOpenAI client instance.
    """
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]

    default_headers = {
        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }

    if client_configs is None:
        client_configs = {}

    # Create a merged config dict with precedence: explicit params > client_configs > defaults
    merged_configs = {
        **client_configs,
        "default_headers": default_headers,
        "api_key": api_key,
    }

    if base_url is not None:
        merged_configs["base_url"] = base_url

    return AsyncOpenAI(**merged_configs)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
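The precedence rule (explicit parameters beat `client_configs`, which beat defaults) falls out of the dict-merge order above. A hedged usage sketch; the key, URL, and config values are placeholders:

```python
# Sketch only: client_configs entries such as timeout/max_retries pass through
# to AsyncOpenAI, while api_key and base_url given explicitly always win.
client = create_openai_async_client(
    api_key="sk-...",                        # explicit, overrides env/config values
    base_url="https://api.openai.com/v1",
    client_configs={"timeout": 30.0, "max_retries": 0},
)
```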
@@ -61,29 +102,52 @@ async def openai_complete_if_cache(
    token_tracker: Any | None = None,
    **kwargs: Any,
) -> str:
    """Complete a prompt using OpenAI's API with caching support.

    Args:
        model: The OpenAI model to use.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous messages in the conversation.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        **kwargs: Additional keyword arguments to pass to the OpenAI API.
            Special kwargs:
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
              These will be passed to the client constructor but will be overridden by
              explicit parameters (api_key, base_url).
            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.

    Returns:
        The completed text or an async iterator of text chunks if streaming.

    Raises:
        InvalidResponseError: If the response from OpenAI is invalid or empty.
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    if history_messages is None:
        history_messages = []
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]

    default_headers = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }

    # Set openai logger level to INFO when VERBOSE_DEBUG is off
    if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
        logging.getLogger("openai").setLevel(logging.INFO)

    openai_async_client = (
        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
        if base_url is None
        else AsyncOpenAI(
            base_url=base_url, default_headers=default_headers, api_key=api_key
        )
    )
    # Extract client configuration options
    client_configs = kwargs.pop("openai_client_configs", {})

    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)

    # Prepare messages
    messages: list[dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
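A hedged calling sketch for the helper above; the model name and config values are illustrative, and `openai_client_configs` is one of the special kwargs the docstring documents:

```python
# Sketch only: calling the completion helper with a special kwarg that is
# consumed by the wrapper, not forwarded to the OpenAI API.
import asyncio


async def main() -> None:
    text = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize LightRAG in one sentence.",
        system_prompt="You are a concise assistant.",
        openai_client_configs={"timeout": 30.0},  # forwarded to AsyncOpenAI
    )
    print(text)


asyncio.run(main())
```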
@@ -272,21 +336,32 @@ async def openai_embed(
    model: str = "text-embedding-3-small",
    base_url: str = None,
    api_key: str = None,
    client_configs: dict[str, Any] = None,
) -> np.ndarray:
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]
    """Generate embeddings for a list of texts using OpenAI's API.

    default_headers = {
        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }
    openai_async_client = (
        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
        if base_url is None
        else AsyncOpenAI(
            base_url=base_url, default_headers=default_headers, api_key=api_key
        )
    )
    Args:
        texts: List of texts to embed.
        model: The OpenAI embedding model to use.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        A numpy array of embeddings, one per input text.

    Raises:
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    response = await openai_async_client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
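Per its docstring, the function returns one embedding row per input text. A hedged usage sketch; the model name is the function's default and the dimensionality shown is what that model typically returns:

```python
# Sketch only: embedding a small batch and inspecting the result shape.
import asyncio


async def main() -> None:
    vectors = await openai_embed(
        ["knowledge graphs", "vector search"],
        model="text-embedding-3-small",
    )
    print(vectors.shape)  # e.g. (2, 1536)


asyncio.run(main())
```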
@@ -26,7 +26,6 @@ from .utils import (
    CacheData,
    statistic_data,
    get_conversation_turns,
    verbose_debug,
)
from .base import (
    BaseGraphStorage,
@@ -442,6 +441,13 @@ async def extract_entities(

    processed_chunks = 0
    total_chunks = len(ordered_chunks)
    total_entities_count = 0
    total_relations_count = 0

    # Get lock manager from shared storage
    from .kg.shared_storage import get_graph_db_lock

    graph_db_lock = get_graph_db_lock(enable_logging=False)

    async def _user_llm_func_with_cache(
        input_text: str, history_messages: list[dict[str, str]] = None
@@ -540,7 +546,7 @@ async def extract_entities(
        chunk_key_dp (tuple[str, TextChunkSchema]):
            ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
        """
        nonlocal processed_chunks
        nonlocal processed_chunks, total_entities_count, total_relations_count
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
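Because each chunk is processed by a nested coroutine, the running totals are plain integers in the enclosing scope, updated through `nonlocal`. A minimal self-contained sketch of that accumulation pattern:

```python
# Sketch only: accumulating totals from nested coroutines via `nonlocal`.
import asyncio


async def process_all(chunks: list[str]) -> tuple[int, int]:
    processed_chunks = 0
    total_entities_count = 0

    async def _process_one(chunk: str) -> None:
        nonlocal processed_chunks, total_entities_count
        await asyncio.sleep(0)  # stand-in for the LLM extraction call
        processed_chunks += 1
        total_entities_count += len(chunk.split())  # stand-in for extracted entities

    await asyncio.gather(*(_process_one(c) for c in chunks))
    return processed_chunks, total_entities_count


print(asyncio.run(process_all(["alpha beta", "gamma"])))  # (2, 3)
```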
@@ -598,102 +604,74 @@ async def extract_entities(
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
        return dict(maybe_nodes), dict(maybe_edges)

    tasks = [_process_single_content(c) for c in ordered_chunks]
    results = await asyncio.gather(*tasks)
        # Use graph database lock to ensure atomic merges and updates
        chunk_entities_data = []
        chunk_relationships_data = []

    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[tuple(sorted(k))].extend(v)

    from .kg.shared_storage import get_graph_db_lock

    graph_db_lock = get_graph_db_lock(enable_logging=False)

    # Ensure that nodes and edges are merged and upserted atomically
    async with graph_db_lock:
        all_entities_data = await asyncio.gather(
            *[
                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
                for k, v in maybe_nodes.items()
            ]
        )

        all_relationships_data = await asyncio.gather(
            *[
                _merge_edges_then_upsert(
                    k[0], k[1], v, knowledge_graph_inst, global_config
        async with graph_db_lock:
            # Process and update entities
            for entity_name, entities in maybe_nodes.items():
                entity_data = await _merge_nodes_then_upsert(
                    entity_name, entities, knowledge_graph_inst, global_config
                )
                for k, v in maybe_edges.items()
            ]
        )
                chunk_entities_data.append(entity_data)

    if not (all_entities_data or all_relationships_data):
        log_message = "Didn't extract any entities and relationships."
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
        return
            # Process and update relationships
            for edge_key, edges in maybe_edges.items():
                # Ensure edge direction consistency
                sorted_edge_key = tuple(sorted(edge_key))
                edge_data = await _merge_edges_then_upsert(
                    sorted_edge_key[0],
                    sorted_edge_key[1],
                    edges,
                    knowledge_graph_inst,
                    global_config,
                )
                chunk_relationships_data.append(edge_data)

    if not all_entities_data:
        log_message = "Didn't extract any entities"
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
    if not all_relationships_data:
        log_message = "Didn't extract any relationships"
        logger.info(log_message)
        if pipeline_status is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)
            # Update vector database (within the same lock to ensure atomicity)
            if entity_vdb is not None and chunk_entities_data:
                data_for_vdb = {
                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                        "entity_name": dp["entity_name"],
                        "entity_type": dp["entity_type"],
                        "content": f"{dp['entity_name']}\n{dp['description']}",
                        "source_id": dp["source_id"],
                        "file_path": dp.get("file_path", "unknown_source"),
                    }
                    for dp in chunk_entities_data
                }
                await entity_vdb.upsert(data_for_vdb)

    log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
            if relationships_vdb is not None and chunk_relationships_data:
                data_for_vdb = {
                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                        "src_id": dp["src_id"],
                        "tgt_id": dp["tgt_id"],
                        "keywords": dp["keywords"],
                        "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
                        "source_id": dp["source_id"],
                        "file_path": dp.get("file_path", "unknown_source"),
                    }
                    for dp in chunk_relationships_data
                }
                await relationships_vdb.upsert(data_for_vdb)

            # Update counters
            total_entities_count += len(chunk_entities_data)
            total_relations_count += len(chunk_relationships_data)

    # Handle all chunks in parallel
    tasks = [_process_single_content(c) for c in ordered_chunks]
    await asyncio.gather(*tasks)

    log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
    logger.info(log_message)
    if pipeline_status is not None:
        async with pipeline_status_lock:
            pipeline_status["latest_message"] = log_message
            pipeline_status["history_messages"].append(log_message)
    verbose_debug(
        f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
    )
    verbose_debug(f"New relationships:{all_relationships_data}")

    if entity_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "entity_name": dp["entity_name"],
                "entity_type": dp["entity_type"],
                "content": f"{dp['entity_name']}\n{dp['description']}",
                "source_id": dp["source_id"],
                "file_path": dp.get("file_path", "unknown_source"),
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    if relationships_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                "src_id": dp["src_id"],
                "tgt_id": dp["tgt_id"],
                "keywords": dp["keywords"],
                "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
                "source_id": dp["source_id"],
                "file_path": dp.get("file_path", "unknown_source"),
            }
            for dp in all_relationships_data
        }
        await relationships_vdb.upsert(data_for_vdb)


async def kg_query(
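The restructuring in this hunk moves merging and vector upserts inside each chunk's coroutine, serialized by a shared graph-database lock, instead of one global merge after `asyncio.gather`. A minimal sketch of that concurrency shape, using a plain `asyncio.Lock` in place of the shared-storage lock manager:

```python
# Sketch only: per-task critical sections serialized by a shared lock,
# mirroring how each chunk merges its results under graph_db_lock.
import asyncio

graph_db_lock = asyncio.Lock()
merged: dict[str, int] = {}


async def process_chunk(name: str) -> None:
    await asyncio.sleep(0)  # extraction work runs in parallel outside the lock
    async with graph_db_lock:  # merges happen one chunk at a time
        merged[name] = merged.get(name, 0) + 1


async def main() -> None:
    await asyncio.gather(*(process_chunk(c) for c in ["a", "b", "a"]))
    print(merged)  # {'a': 2, 'b': 1}


asyncio.run(main())
```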
@@ -720,8 +698,7 @@ async def kg_query(
    if cached_response is not None:
        return cached_response

    # Extract keywords using extract_keywords_only function which already supports conversation history
    hl_keywords, ll_keywords = await extract_keywords_only(
    hl_keywords, ll_keywords = await get_keywords_from_query(
        query, query_param, global_config, hashing_kv
    )

@@ -817,6 +794,38 @@ async def kg_query(
    return response


async def get_keywords_from_query(
    query: str,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Retrieves high-level and low-level keywords for RAG operations.

    This function checks if keywords are already provided in query parameters,
    and if not, extracts them from the query text using LLM.

    Args:
        query: The user's query text
        query_param: Query parameters that may contain pre-defined keywords
        global_config: Global configuration dictionary
        hashing_kv: Optional key-value storage for caching results

    Returns:
        A tuple containing (high_level_keywords, low_level_keywords)
    """
    # Check if pre-defined keywords are already provided
    if query_param.hl_keywords or query_param.ll_keywords:
        return query_param.hl_keywords, query_param.ll_keywords

    # Extract keywords using extract_keywords_only function which already supports conversation history
    hl_keywords, ll_keywords = await extract_keywords_only(
        query, query_param, global_config, hashing_kv
    )
    return hl_keywords, ll_keywords


async def extract_keywords_only(
    text: str,
    param: QueryParam,
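Pre-defining keywords on the query parameters therefore bypasses the LLM extraction step entirely. A hedged usage sketch, assuming `QueryParam` exposes the `hl_keywords`/`ll_keywords` fields the function reads:

```python
# Sketch only: pre-defined keywords skip LLM extraction inside
# get_keywords_from_query.
from lightrag import QueryParam

param = QueryParam(mode="hybrid")
param.hl_keywords = ["graph construction", "entity merging"]
param.ll_keywords = ["LIGHTRAG_GRAPH_NODES"]
# get_keywords_from_query(...) would now return these lists unchanged.
```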
@@ -957,8 +966,7 @@ async def mix_kg_vector_query(
    # 2. Execute knowledge graph and vector searches in parallel
    async def get_kg_context():
        try:
            # Extract keywords using extract_keywords_only function which already supports conversation history
            hl_keywords, ll_keywords = await extract_keywords_only(
            hl_keywords, ll_keywords = await get_keywords_from_query(
                query, query_param, global_config, hashing_kv
            )

@@ -1339,7 +1347,9 @@ async def _get_node_data(

    text_units_section_list = [["id", "content", "file_path"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"], t["file_path"]])
        text_units_section_list.append(
            [i, t["content"], t.get("file_path", "unknown_source")]
        )
    text_units_context = list_of_list_to_csv(text_units_section_list)
    return entities_context, relations_context, text_units_context

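The change swaps a hard key lookup for `dict.get` with a fallback, so text chunks without recorded provenance no longer raise `KeyError`. A self-contained sketch of the section-list construction:

```python
# Sketch only: building the text-unit rows with a safe file_path fallback.
use_text_units = [
    {"content": "chunk one", "file_path": "docs/a.md"},
    {"content": "chunk two"},  # no provenance recorded
]

rows = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
    rows.append([i, t["content"], t.get("file_path", "unknown_source")])

print(rows[2])  # [1, 'chunk two', 'unknown_source']
```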
@@ -2043,16 +2053,13 @@ async def query_with_keywords(
        Query response or async iterator
    """
    # Extract keywords
    hl_keywords, ll_keywords = await extract_keywords_only(
        text=query,
        param=param,
    hl_keywords, ll_keywords = await get_keywords_from_query(
        query=query,
        query_param=param,
        global_config=global_config,
        hashing_kv=hashing_kv,
    )

    param.hl_keywords = hl_keywords
    param.ll_keywords = ll_keywords

    # Create a new string with the prompt and the keywords
    ll_keywords_str = ", ".join(ll_keywords)
    hl_keywords_str = ", ".join(hl_keywords)

@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):

class KnowledgeGraph(BaseModel):
    nodes: list[KnowledgeGraphNode] = []
    edges: list[KnowledgeGraphEdge] = []
    is_truncated: bool = False
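The new `is_truncated` flag lets API consumers distinguish a complete graph from one clipped at `max_nodes`. A hedged construction sketch, assuming the node model exposes `id`/`labels`/`properties` fields:

```python
# Sketch only: the flag signals that max_nodes clipped the result.
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode

kg = KnowledgeGraph(
    nodes=[KnowledgeGraphNode(id="n1", labels=["Entity"], properties={})],
    edges=[],
    is_truncated=True,  # more nodes existed than the caller's max_nodes
)
print(kg.is_truncated)
```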