Merge pull request #744 from danielaskdd/select-datastore-in-api-server

Add datastore selection feature for API Server
This commit is contained in:
zrguo
2025-02-13 20:04:16 +08:00
committed by GitHub
23 changed files with 897 additions and 444 deletions

View File

@@ -1,12 +1,30 @@
# Server Configuration ### Server Configuration
HOST=0.0.0.0 #HOST=0.0.0.0
PORT=9621 #PORT=9621
#NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
# Directory Configuration ### Optional SSL Configuration
WORKING_DIR=/app/data/rag_storage #SSL=true
INPUT_DIR=/app/data/inputs #SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem
# RAG Configuration ### Security (empty for no api-key is needed)
# LIGHTRAG_API_KEY=your-secure-api-key-here
### Directory Configuration
# WORKING_DIR=./rag_storage
# INPUT_DIR=./inputs
### Logging level
LOG_LEVEL=INFO
### Optional Timeout
TIMEOUT=300
# Ollama Emulating Model Tag
# OLLAMA_EMULATING_MODEL_TAG=latest
### RAG Configuration
MAX_ASYNC=4 MAX_ASYNC=4
MAX_TOKENS=32768 MAX_TOKENS=32768
EMBEDDING_DIM=1024 EMBEDDING_DIM=1024
@@ -14,56 +32,42 @@ MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3 #HISTORY_TURNS=3
#CHUNK_SIZE=1200 #CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100 #CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.4 # 0.2 while not running API server #COSINE_THRESHOLD=0.2
#TOP_K=50 # 60 while not running API server #TOP_K=60
# LLM Configuration (Use valid host. For local services, you can use host.docker.internal) ### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
# Ollama example ### Ollama example
LLM_BINDING=ollama LLM_BINDING=ollama
LLM_BINDING_HOST=http://host.docker.internal:11434 LLM_BINDING_HOST=http://host.docker.internal:11434
LLM_MODEL=mistral-nemo:latest LLM_MODEL=mistral-nemo:latest
# OpenAI alike example ### OpenAI alike example
# LLM_BINDING=openai # LLM_BINDING=openai
# LLM_MODEL=deepseek-chat # LLM_MODEL=deepseek-chat
# LLM_BINDING_HOST=https://api.deepseek.com # LLM_BINDING_HOST=https://api.deepseek.com
# LLM_BINDING_API_KEY=your_api_key # LLM_BINDING_API_KEY=your_api_key
# for OpenAI LLM (LLM_BINDING_API_KEY take priority) ### for OpenAI LLM (LLM_BINDING_API_KEY take priority)
# OPENAI_API_KEY=your_api_key # OPENAI_API_KEY=your_api_key
# Lollms example ### Lollms example
# LLM_BINDING=lollms # LLM_BINDING=lollms
# LLM_BINDING_HOST=http://host.docker.internal:9600 # LLM_BINDING_HOST=http://host.docker.internal:9600
# LLM_MODEL=mistral-nemo:latest # LLM_MODEL=mistral-nemo:latest
# Embedding Configuration (Use valid host. For local services, you can use host.docker.internal) ### Embedding Configuration (Use valid host. For local services, you can use host.docker.internal)
# Ollama example # Ollama example
EMBEDDING_BINDING=ollama EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://host.docker.internal:11434 EMBEDDING_BINDING_HOST=http://host.docker.internal:11434
EMBEDDING_MODEL=bge-m3:latest EMBEDDING_MODEL=bge-m3:latest
# Lollms example ### Lollms example
# EMBEDDING_BINDING=lollms # EMBEDDING_BINDING=lollms
# EMBEDDING_BINDING_HOST=http://host.docker.internal:9600 # EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
# EMBEDDING_MODEL=bge-m3:latest # EMBEDDING_MODEL=bge-m3:latest
# Security (empty for no key) ### Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
LIGHTRAG_API_KEY=your-secure-api-key-here
# Logging
LOG_LEVEL=INFO
# Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem
# Optional Timeout
#TIMEOUT=30
# Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
# AZURE_OPENAI_API_VERSION=2024-08-01-preview # AZURE_OPENAI_API_VERSION=2024-08-01-preview
# AZURE_OPENAI_DEPLOYMENT=gpt-4o # AZURE_OPENAI_DEPLOYMENT=gpt-4o
# AZURE_OPENAI_API_KEY=myapikey # AZURE_OPENAI_API_KEY=myapikey
@@ -72,6 +76,57 @@ LOG_LEVEL=INFO
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
# AZURE_EMBEDDING_API_VERSION=2023-05-15 # AZURE_EMBEDDING_API_VERSION=2023-05-15
### Data storage selection
# LIGHTRAG_KV_STORAGE=PGKVStorage
# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
# LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
# Ollama Emulating Model Tag ### Oracle Database Configuration
# OLLAMA_EMULATING_MODEL_TAG=latest ORACLE_DSN=localhost:1521/XEPDB1
ORACLE_USER=your_username
ORACLE_PASSWORD='your_password'
ORACLE_CONFIG_DIR=/path/to/oracle/config
#ORACLE_WALLET_LOCATION=/path/to/wallet # optional
#ORACLE_WALLET_PASSWORD='your_password' # optional
#ORACLE_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
### TiDB Configuration
TIDB_HOST=localhost
TIDB_PORT=4000
TIDB_USER=your_username
TIDB_PASSWORD='your_password'
TIDB_DATABASE=your_database
#TIDB_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
### PostgreSQL Configuration
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database
#POSTGRES_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
### Independent AGM Configuration(not for AMG embedded in PostreSQL)
AGE_POSTGRES_DB=
AGE_POSTGRES_USER=
AGE_POSTGRES_PASSWORD=
AGE_POSTGRES_HOST=
# AGE_POSTGRES_PORT=8529
# AGE Graph Name(apply to PostgreSQL and independent AGM)
# AGE_GRAPH_NAME=lightrag # deprecated, use NAME_SPACE_PREFIX instead
### Neo4j Configuration
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD='your_password'
### MongoDB Configuration
MONGODB_URI=mongodb://root:root@localhost:27017/
MONGODB_DATABASE=LightRAG
MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
### Qdrant
QDRANT_URL=http://localhost:16333
QDRANT_API_KEY=your-api-key # 可选

View File

@@ -13,3 +13,28 @@ uri=redis://localhost:6379/1
[qdrant] [qdrant]
uri = http://localhost:16333 uri = http://localhost:16333
[oracle]
dsn = localhost:1521/XEPDB1
user = your_username
password = your_password
config_dir = /path/to/oracle/config
wallet_location = /path/to/wallet # 可选
wallet_password = your_wallet_password # 可选
workspace = default # 可选,默认为default
[tidb]
host = localhost
port = 4000
user = your_username
password = your_password
database = your_database
workspace = default # 可选,默认为default
[postgres]
host = localhost
port = 5432
user = your_username
password = your_password
database = your_database
workspace = default # 可选,默认为default

View File

@@ -103,66 +103,23 @@ After starting the lightrag-server, you can add an Ollama-type connection in the
LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables. LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.
For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD remains at its default value of 0.2 in LightRAG, many irrelevant entities and relations would be retrieved and sent to the LLM. Default `TOP_K` is set to `60`. Default `COSINE_THRESHOLD` are set to `0.2`.
### Environment Variables ### Environment Variables
You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of available environment variables: You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. A sample file `.env.example` is provided for your convenience.
```env ### Config.ini
# Server Configuration
HOST=0.0.0.0
PORT=9621
# Directory Configuration Datastorage configuration can be also set by config.ini. A sample file `config.ini.example` is provided for your convenience.
WORKING_DIR=/app/data/rag_storage
INPUT_DIR=/app/data/inputs
# RAG Configuration
MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3
#CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.4
#TOP_K=50
# LLM Configuration
LLM_BINDING=ollama
LLM_BINDING_HOST=http://localhost:11434
LLM_MODEL=mistral-nemo:latest
# must be set if using OpenAI LLM (LLM_MODEL must be set or set by command line parms)
OPENAI_API_KEY=you_api_key
# Embedding Configuration
EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://localhost:11434
EMBEDDING_MODEL=bge-m3:latest
# Security
#LIGHTRAG_API_KEY=you-api-key-for-accessing-LightRAG
# Logging
LOG_LEVEL=INFO
# Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem
# Optional Timeout
#TIMEOUT=30
```
### Configuration Priority ### Configuration Priority
The configuration values are loaded in the following order (highest priority first): The configuration values are loaded in the following order (highest priority first):
1. Command-line arguments 1. Command-line arguments
2. Environment variables 2. Environment variables
3. Default values 3. Config.ini
4. Defaul values
For example: For example:
```bash ```bash
@@ -173,7 +130,69 @@ python lightrag.py --port 8080
PORT=7000 python lightrag.py PORT=7000 python lightrag.py
``` ```
#### LightRag Server Options > Best practices: you can set your database setting in Config.ini while testing, and you use .env for production.
### Storage Types Supported
LightRAG uses 4 types of storage for difference purposes:
* KV_STORAGEllm response cache, text chunks, document information
* VECTOR_STORAGEentities vectors, relation vectors, chunks vectors
* GRAPH_STORAGEentity relation graph
* DOC_STATUS_STORAGEdocuments indexing status
Each storage type have servals implementations:
* KV_STORAGE supported implement-name
```
JsonKVStorage JsonFile(default)
MongoKVStorage MogonDB
RedisKVStorage Redis
TiDBKVStorage TiDB
PGKVStorage Postgres
OracleKVStorage Oracle
```
* GRAPH_STORAGE supported implement-name
```
NetworkXStorage NetworkX(defualt)
Neo4JStorage Neo4J
MongoGraphStorage MongoDB
TiDBGraphStorage TiDB
AGEStorage AGE
GremlinStorage Gremlin
PGGraphStorage Postgres
OracleGraphStorage Postgres
```
* VECTOR_STORAGE supported implement-name
```
NanoVectorDBStorage NanoVector(default)
MilvusVectorDBStorge Milvus
ChromaVectorDBStorage Chroma
TiDBVectorDBStorage TiDB
PGVectorStorage Postgres
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
OracleVectorDBStorag Oracle
```
* DOC_STATUS_STORAGEsupported implement-name
```
JsonDocStatusStorage JsonFile(default)
PGDocStatusStorage Postgres
MongoDocStatusStorage MongoDB
```
### How Select Storage Implementation
You can select storage implementation by enviroment variables or command line arguments. You can not change storage implementation selection after you add documents to LightRAG. Data migration from one storage implementation to anthor is not supported yet. For further information please read the sample env file or config.ini file.
### LightRag API Server Comand Line Options
| Parameter | Default | Description | | Parameter | Default | Description |
|-----------|---------|-------------| |-----------|---------|-------------|
@@ -200,6 +219,10 @@ PORT=7000 python lightrag.py
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) | | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. | | --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
| --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. | | --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
| --kv-storage | JsonKVStorage | implement-name of KV_STORAGE |
| --graph-storage | NetworkXStorage | implement-name of GRAPH_STORAGE |
| --vector-storage | NanoVectorDBStorage | implement-name of VECTOR_STORAGE |
| --doc-status-storage | JsonDocStatusStorage | implement-name of DOC_STATUS_STORAGE |
### Example Usage ### Example Usage
@@ -343,6 +366,14 @@ curl -X POST "http://localhost:9621/documents/scan" --max-time 1800
> Ajust max-time according to the estimated index time for all new files. > Ajust max-time according to the estimated index time for all new files.
#### DELETE /documents
Clear all documents from the RAG system.
```bash
curl -X DELETE "http://localhost:9621/documents"
```
### Ollama Emulation Endpoints ### Ollama Emulation Endpoints
#### GET /api/version #### GET /api/version
@@ -372,14 +403,6 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/jso
> For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md) > For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)
#### DELETE /documents
Clear all documents from the RAG system.
```bash
curl -X DELETE "http://localhost:9621/documents"
```
### Utility Endpoints ### Utility Endpoints
#### GET /health #### GET /health

View File

@@ -1 +1 @@
__api_version__ = "1.0.4" __api_version__ = "1.0.5"

View File

@@ -26,7 +26,6 @@ import shutil
import aiofiles import aiofiles
from ascii_colors import trace_exception, ASCIIColors from ascii_colors import trace_exception, ASCIIColors
import sys import sys
import configparser
from fastapi import Depends, Security from fastapi import Depends, Security
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@@ -34,25 +33,47 @@ from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm import pipmaster as pm
from dotenv import load_dotenv from dotenv import load_dotenv
import configparser
from lightrag.utils import logger
from .ollama_api import ( from .ollama_api import (
OllamaAPI, OllamaAPI,
) )
from .ollama_api import ollama_server_infos from .ollama_api import ollama_server_infos
from ..kg.postgres_impl import (
PostgreSQLDB,
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
)
from ..kg.oracle_impl import (
OracleDB,
OracleKVStorage,
OracleVectorDBStorage,
OracleGraphStorage,
)
from ..kg.tidb_impl import (
TiDB,
TiDBKVStorage,
TiDBVectorDBStorage,
TiDBGraphStorage,
)
# Load environment variables # Load environment variables
load_dotenv(override=True) load_dotenv(override=True)
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
class RAGStorageConfig:
class DefaultRAGStorageConfig:
KV_STORAGE = "JsonKVStorage" KV_STORAGE = "JsonKVStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
GRAPH_STORAGE = "NetworkXStorage"
VECTOR_STORAGE = "NanoVectorDBStorage" VECTOR_STORAGE = "NanoVectorDBStorage"
GRAPH_STORAGE = "NetworkXStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
# Initialize rag storage config
rag_storage_config = RAGStorageConfig()
# Global progress tracker # Global progress tracker
scan_progress: Dict = { scan_progress: Dict = {
"is_scanning": False, "is_scanning": False,
@@ -80,61 +101,6 @@ def estimate_tokens(text: str) -> int:
return int(tokens) return int(tokens)
# read config.ini
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Redis config
redis_uri = config.get("redis", "uri", fallback=None)
if redis_uri:
os.environ["REDIS_URI"] = redis_uri
rag_storage_config.KV_STORAGE = "RedisKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
# Neo4j config
neo4j_uri = config.get("neo4j", "uri", fallback=None)
neo4j_username = config.get("neo4j", "username", fallback=None)
neo4j_password = config.get("neo4j", "password", fallback=None)
if neo4j_uri:
os.environ["NEO4J_URI"] = neo4j_uri
os.environ["NEO4J_USERNAME"] = neo4j_username
os.environ["NEO4J_PASSWORD"] = neo4j_password
rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
# Milvus config
milvus_uri = config.get("milvus", "uri", fallback=None)
milvus_user = config.get("milvus", "user", fallback=None)
milvus_password = config.get("milvus", "password", fallback=None)
milvus_db_name = config.get("milvus", "db_name", fallback=None)
if milvus_uri:
os.environ["MILVUS_URI"] = milvus_uri
os.environ["MILVUS_USER"] = milvus_user
os.environ["MILVUS_PASSWORD"] = milvus_password
os.environ["MILVUS_DB_NAME"] = milvus_db_name
rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorage"
# Qdrant config
qdrant_uri = config.get("qdrant", "uri", fallback=None)
qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
if qdrant_uri:
os.environ["QDRANT_URL"] = qdrant_uri
if qdrant_api_key:
os.environ["QDRANT_API_KEY"] = qdrant_api_key
rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
# MongoDB config
mongo_uri = config.get("mongodb", "uri", fallback=None)
mongo_database = config.get("mongodb", "database", fallback="LightRAG")
mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database
rag_storage_config.KV_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
if mongo_graph:
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
def get_default_host(binding_type: str) -> str: def get_default_host(binding_type: str) -> str:
default_hosts = { default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -247,6 +213,16 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.top_k}") ASCIIColors.yellow(f"{args.top_k}")
# System Configuration # System Configuration
ASCIIColors.magenta("\n💾 Storage Configuration:")
ASCIIColors.white(" ├─ KV Storage: ", end="")
ASCIIColors.yellow(f"{args.kv_storage}")
ASCIIColors.white(" ├─ Vector Storage: ", end="")
ASCIIColors.yellow(f"{args.vector_storage}")
ASCIIColors.white(" ├─ Graph Storage: ", end="")
ASCIIColors.yellow(f"{args.graph_storage}")
ASCIIColors.white(" └─ Document Status Storage: ", end="")
ASCIIColors.yellow(f"{args.doc_status_storage}")
ASCIIColors.magenta("\n🛠️ System Configuration:") ASCIIColors.magenta("\n🛠️ System Configuration:")
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
@@ -344,6 +320,35 @@ def parse_args() -> argparse.Namespace:
description="LightRAG FastAPI Server with separate working and input directories" description="LightRAG FastAPI Server with separate working and input directories"
) )
parser.add_argument(
"--kv-storage",
default=get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
),
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
)
parser.add_argument(
"--doc-status-storage",
default=get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
),
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
)
parser.add_argument(
"--graph-storage",
default=get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
),
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
)
parser.add_argument(
"--vector-storage",
default=get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
),
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
)
# Bindings configuration # Bindings configuration
parser.add_argument( parser.add_argument(
"--llm-binding", "--llm-binding",
@@ -528,13 +533,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--top-k", "--top-k",
type=int, type=int,
default=get_env_value("TOP_K", 50, int), default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 50)", help="Number of most similar results to return (default: from env or 60)",
) )
parser.add_argument( parser.add_argument(
"--cosine-threshold", "--cosine-threshold",
type=float, type=float,
default=get_env_value("COSINE_THRESHOLD", 0.4, float), default=get_env_value("COSINE_THRESHOLD", 0.2, float),
help="Cosine similarity threshold (default: from env or 0.4)", help="Cosine similarity threshold (default: from env or 0.4)",
) )
@@ -667,7 +672,14 @@ def get_api_key_dependency(api_key: Optional[str]):
return api_key_auth return api_key_auth
# Global configuration
global_top_k = 60 # default value
def create_app(args): def create_app(args):
global global_top_k
global_top_k = args.top_k # save top_k from args
# Verify that bindings are correctly setup # Verify that bindings are correctly setup
if args.llm_binding not in [ if args.llm_binding not in [
"lollms", "lollms",
@@ -713,25 +725,104 @@ def create_app(args):
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events""" """Lifespan context manager for startup and shutdown events"""
# Startup logic # Initialize database connections
if args.auto_scan_at_startup: postgres_db = None
try: oracle_db = None
new_files = doc_manager.scan_directory_for_new_files() tidb_db = None
for file_path in new_files:
try:
await index_file(file_path)
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
ASCIIColors.info( try:
f"Indexed {len(new_files)} documents from {args.input_dir}" # Check if PostgreSQL is needed
if any(
isinstance(
storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
) )
except Exception as e: for _, storage_instance in storage_instances
logging.error(f"Error during startup indexing: {str(e)}") ):
yield postgres_db = PostgreSQLDB(_get_postgres_config())
# Cleanup logic (if needed) await postgres_db.initdb()
pass await postgres_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
),
):
storage_instance.db = postgres_db
logger.info(f"Injected postgres_db to {storage_name}")
# Check if Oracle is needed
if any(
isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
)
for _, storage_instance in storage_instances
):
oracle_db = OracleDB(_get_oracle_config())
await oracle_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
):
storage_instance.db = oracle_db
logger.info(f"Injected oracle_db to {storage_name}")
# Check if TiDB is needed
if any(
isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
)
for _, storage_instance in storage_instances
):
tidb_db = TiDB(_get_tidb_config())
await tidb_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
):
storage_instance.db = tidb_db
logger.info(f"Injected tidb_db to {storage_name}")
# Auto scan documents if enabled
if args.auto_scan_at_startup:
try:
new_files = doc_manager.scan_directory_for_new_files()
for file_path in new_files:
try:
await index_file(file_path)
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
ASCIIColors.info(
f"Indexed {len(new_files)} documents from {args.input_dir}"
)
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
yield
finally:
# Cleanup database connections
if postgres_db and hasattr(postgres_db, "pool"):
await postgres_db.pool.close()
logger.info("Closed PostgreSQL connection pool")
if oracle_db and hasattr(oracle_db, "pool"):
await oracle_db.pool.close()
logger.info("Closed Oracle connection pool")
if tidb_db and hasattr(tidb_db, "pool"):
await tidb_db.pool.close()
logger.info("Closed TiDB connection pool")
# Initialize FastAPI # Initialize FastAPI
app = FastAPI( app = FastAPI(
@@ -754,6 +845,92 @@ def create_app(args):
allow_headers=["*"], allow_headers=["*"],
) )
# Database configuration functions
def _get_postgres_config():
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
def _get_oracle_config():
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
def _get_tidb_config():
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# Create the optional API key dependency # Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key) optional_api_key = get_api_key_dependency(api_key)
@@ -872,10 +1049,10 @@ def create_app(args):
if args.llm_binding == "lollms" or args.llm_binding == "ollama" if args.llm_binding == "lollms" or args.llm_binding == "ollama"
else {}, else {},
embedding_func=embedding_func, embedding_func=embedding_func,
kv_storage=rag_storage_config.KV_STORAGE, kv_storage=args.kv_storage,
graph_storage=rag_storage_config.GRAPH_STORAGE, graph_storage=args.graph_storage,
vector_storage=rag_storage_config.VECTOR_STORAGE, vector_storage=args.vector_storage,
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={ vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold "cosine_better_than_threshold": args.cosine_threshold
}, },
@@ -903,10 +1080,10 @@ def create_app(args):
llm_model_max_async=args.max_async, llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens, llm_model_max_token_size=args.max_tokens,
embedding_func=embedding_func, embedding_func=embedding_func,
kv_storage=rag_storage_config.KV_STORAGE, kv_storage=args.kv_storage,
graph_storage=rag_storage_config.GRAPH_STORAGE, graph_storage=args.graph_storage,
vector_storage=rag_storage_config.VECTOR_STORAGE, vector_storage=args.vector_storage,
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={ vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold "cosine_better_than_threshold": args.cosine_threshold
}, },
@@ -920,6 +1097,18 @@ def create_app(args):
namespace_prefix=args.namespace_prefix, namespace_prefix=args.namespace_prefix,
) )
# Collect all storage instances
storage_instances = [
("full_docs", rag.full_docs),
("text_chunks", rag.text_chunks),
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
("entities_vdb", rag.entities_vdb),
("relationships_vdb", rag.relationships_vdb),
("chunks_vdb", rag.chunks_vdb),
("doc_status", rag.doc_status),
("llm_response_cache", rag.llm_response_cache),
]
async def index_file(file_path: Union[str, Path]) -> None: async def index_file(file_path: Union[str, Path]) -> None:
"""Index all files inside the folder with support for multiple file formats """Index all files inside the folder with support for multiple file formats
@@ -1100,7 +1289,7 @@ def create_app(args):
mode=request.mode, mode=request.mode,
stream=request.stream, stream=request.stream,
only_need_context=request.only_need_context, only_need_context=request.only_need_context,
top_k=args.top_k, top_k=global_top_k,
), ),
) )
@@ -1142,7 +1331,7 @@ def create_app(args):
mode=request.mode, mode=request.mode,
stream=True, stream=True,
only_need_context=request.only_need_context, only_need_context=request.only_need_context,
top_k=args.top_k, top_k=global_top_k,
), ),
) )
@@ -1432,7 +1621,7 @@ def create_app(args):
return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
# Add Ollama API routes # Add Ollama API routes
ollama_api = OllamaAPI(rag) ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api") app.include_router(ollama_api.router, prefix="/api")
@app.get("/documents", dependencies=[Depends(optional_api_key)]) @app.get("/documents", dependencies=[Depends(optional_api_key)])
@@ -1460,10 +1649,10 @@ def create_app(args):
"embedding_binding_host": args.embedding_binding_host, "embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model, "embedding_model": args.embedding_model,
"max_tokens": args.max_tokens, "max_tokens": args.max_tokens,
"kv_storage": rag_storage_config.KV_STORAGE, "kv_storage": args.kv_storage,
"doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE, "doc_status_storage": args.doc_status_storage,
"graph_storage": rag_storage_config.GRAPH_STORAGE, "graph_storage": args.graph_storage,
"vector_storage": rag_storage_config.VECTOR_STORAGE, "vector_storage": args.vector_storage,
}, },
} }

View File

@@ -148,9 +148,10 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
class OllamaAPI: class OllamaAPI:
def __init__(self, rag: LightRAG): def __init__(self, rag: LightRAG, top_k: int = 60):
self.rag = rag self.rag = rag
self.ollama_server_infos = ollama_server_infos self.ollama_server_infos = ollama_server_infos
self.top_k = top_k
self.router = APIRouter() self.router = APIRouter()
self.setup_routes() self.setup_routes()
@@ -381,7 +382,7 @@ class OllamaAPI:
"stream": request.stream, "stream": request.stream,
"only_need_context": False, "only_need_context": False,
"conversation_history": conversation_history, "conversation_history": conversation_history,
"top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50, "top_k": self.top_k,
} }
if ( if (

View File

@@ -75,8 +75,8 @@ class AGEStorage(BaseGraphStorage):
.replace("'", "\\'") .replace("'", "\\'")
) )
HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'") HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'")
PORT = int(os.environ["AGE_POSTGRES_PORT"]) PORT = os.environ.get("AGE_POSTGRES_PORT", "8529")
self.graph_name = os.environ["AGE_GRAPH_NAME"] self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"

View File

@@ -1,4 +1,3 @@
import os
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union from typing import Union
@@ -13,15 +12,17 @@ from lightrag.utils import logger
class ChromaVectorDBStorage(BaseVectorStorage): class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation.""" """ChromaDB vector storage implementation."""
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
try: try:
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
user_collection_settings = config.get("collection_settings", {}) user_collection_settings = config.get("collection_settings", {})
# Default HNSW index settings for ChromaDB # Default HNSW index settings for ChromaDB

View File

@@ -23,14 +23,17 @@ class FaissVectorDBStorage(BaseVectorStorage):
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
""" """
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Grab config values if available # Grab config values if available
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
# Where to save index file if you want persistent storage # Where to save index file if you want persistent storage
self._faiss_index_file = os.path.join( self._faiss_index_file = os.path.join(

View File

@@ -47,7 +47,9 @@ class GremlinStorage(BaseGraphStorage):
# All vertices will have graph={GRAPH} property, so that we can # All vertices will have graph={GRAPH} property, so that we can
# have several logical graphs for one source # have several logical graphs for one source
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"]) GRAPH = GremlinStorage._to_value_map(
os.environ.get("GREMLIN_GRAPH", "LightRAG")
)
self.graph_name = GRAPH self.graph_name = GRAPH

View File

@@ -5,16 +5,22 @@ from dataclasses import dataclass
import numpy as np import numpy as np
from lightrag.utils import logger from lightrag.utils import logger
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
import pipmaster as pm import pipmaster as pm
import configparser
if not pm.is_installed("pymilvus"): if not pm.is_installed("pymilvus"):
pm.install("pymilvus") pm.install("pymilvus")
from pymilvus import MilvusClient from pymilvus import MilvusClient
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass @dataclass
class MilvusVectorDBStorage(BaseVectorStorage): class MilvusVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod @staticmethod
def create_collection_if_not_exist( def create_collection_if_not_exist(
client: MilvusClient, collection_name: str, **kwargs client: MilvusClient, collection_name: str, **kwargs
@@ -26,15 +32,37 @@ class MilvusVectorDBStorage(BaseVectorStorage):
) )
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._client = MilvusClient( self._client = MilvusClient(
uri=os.environ.get( uri=os.environ.get(
"MILVUS_URI", "MILVUS_URI",
os.path.join(self.global_config["working_dir"], "milvus_lite.db"), config.get(
"milvus",
"uri",
fallback=os.path.join(
self.global_config["working_dir"], "milvus_lite.db"
),
),
),
user=os.environ.get(
"MILVUS_USER", config.get("milvus", "user", fallback=None)
),
password=os.environ.get(
"MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)
),
token=os.environ.get(
"MILVUS_TOKEN", config.get("milvus", "token", fallback=None)
),
db_name=os.environ.get(
"MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)
), ),
user=os.environ.get("MILVUS_USER", ""),
password=os.environ.get("MILVUS_PASSWORD", ""),
token=os.environ.get("MILVUS_TOKEN", ""),
db_name=os.environ.get("MILVUS_DB_NAME", ""),
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
MilvusVectorDBStorage.create_collection_if_not_exist( MilvusVectorDBStorage.create_collection_if_not_exist(
@@ -85,7 +113,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
data=embedding, data=embedding,
limit=top_k, limit=top_k,
output_fields=list(self.meta_fields), output_fields=list(self.meta_fields),
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}}, search_params={
"metric_type": "COSINE",
"params": {"radius": self.cosine_better_than_threshold},
},
) )
print(results) print(results)
return [ return [

View File

@@ -1,8 +1,8 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
import pipmaster as pm import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
if not pm.is_installed("pymongo"): if not pm.is_installed("pymongo"):
@@ -12,7 +12,6 @@ if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
from typing import Any, List, Tuple, Union from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient from pymongo import MongoClient
@@ -27,13 +26,27 @@ from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
client = MongoClient( client = MongoClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
)
) )
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
self._data = database.get_collection(self.namespace) self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as KV {self.namespace}") logger.info(f"Use MongoDB as KV {self.namespace}")
@@ -173,10 +186,25 @@ class MongoGraphStorage(BaseGraphStorage):
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self.client = AsyncIOMotorClient( self.client = AsyncIOMotorClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
) )
self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")] self.db = self.client[
self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")] os.environ.get(
"MONGO_DATABASE",
mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
)
]
self.collection = self.db[
os.environ.get(
"MONGO_KG_COLLECTION",
config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
)
]
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------

View File

@@ -73,16 +73,19 @@ from lightrag.base import (
@dataclass @dataclass
class NanoVectorDBStorage(BaseVectorStorage): class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Initialize lock only for file operations # Initialize lock only for file operations
self._save_lock = asyncio.Lock() self._save_lock = asyncio.Lock()
# Use global config value if specified, otherwise use default # Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -139,9 +142,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def query(self, query: str, top_k=5): async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
embedding = embedding[0] embedding = embedding[0]
logger.info(
f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
)
results = self._client.query( results = self._client.query(
query=embedding, query=embedding,
top_k=top_k, top_k=top_k,

View File

@@ -5,6 +5,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict from typing import Any, Union, Tuple, List, Dict
import pipmaster as pm import pipmaster as pm
import configparser
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
pm.install("neo4j") pm.install("neo4j")
@@ -28,6 +29,10 @@ from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass @dataclass
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@staticmethod @staticmethod
@@ -42,13 +47,22 @@ class Neo4JStorage(BaseGraphStorage):
) )
self._driver = None self._driver = None
self._driver_lock = asyncio.Lock() self._driver_lock = asyncio.Lock()
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"] URI = os.environ["NEO4J_URI", config.get("neo4j", "uri", fallback=None)]
PASSWORD = os.environ["NEO4J_PASSWORD"] USERNAME = os.environ[
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
]
PASSWORD = os.environ[
"NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
]
MAX_CONNECTION_POOL_SIZE = os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE",
config.get("neo4j", "connection_pool_size", fallback=800),
)
DATABASE = os.environ.get( DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
) )
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD) URI, auth=(USERNAME, PASSWORD)
) )

View File

@@ -1,6 +1,5 @@
import array import array
import asyncio import asyncio
import os
# import html # import html
# import os # import os
@@ -172,8 +171,8 @@ class OracleDB:
@dataclass @dataclass
class OracleKVStorage(BaseKVStorage): class OracleKVStorage(BaseKVStorage):
# should pass db object to self.db # db instance must be injected before use
db: OracleDB = None # db: OracleDB
meta_fields = None meta_fields = None
def __post_init__(self): def __post_init__(self):
@@ -318,16 +317,18 @@ class OracleKVStorage(BaseKVStorage):
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db # db instance must be injected before use
db: OracleDB = None # db: OracleDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据""" """向向量数据库中插入数据"""
@@ -361,7 +362,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
"""基于Oracle的图存储模块""" # db instance must be injected before use
# db: OracleDB
def __post_init__(self): def __post_init__(self):
"""从graphml文件加载图""" """从graphml文件加载图"""

View File

@@ -177,7 +177,8 @@ class PostgreSQLDB:
@dataclass @dataclass
class PGKVStorage(BaseKVStorage): class PGKVStorage(BaseKVStorage):
db: PostgreSQLDB = None # db instance must be injected before use
# db: PostgreSQLDB
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -296,16 +297,19 @@ class PGKVStorage(BaseKVStorage):
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) # db instance must be injected before use
db: PostgreSQLDB = None # db: PostgreSQLDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
def _upsert_chunks(self, item: dict): def _upsert_chunks(self, item: dict):
try: try:
@@ -416,20 +420,14 @@ class PGVectorStorage(BaseVectorStorage):
@dataclass @dataclass
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):
"""PostgreSQL implementation of document status storage""" # db instance must be injected before use
# db: PostgreSQLDB
db: PostgreSQLDB = None
def __post_init__(self):
pass
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, data: set[str]) -> set[str]:
"""Return keys that don't exist in storage""" """Return keys that don't exist in storage"""
keys = ",".join([f"'{_id}'" for _id in data]) keys = ",".join([f"'{_id}'" for _id in data])
sql = ( sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})" result = await self.db.query(sql, multirows=True)
)
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None: if result is None:
return set(data) return set(data)
@@ -585,19 +583,15 @@ class PGGraphQueryException(Exception):
@dataclass @dataclass
class PGGraphStorage(BaseGraphStorage): class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = None # db instance must be injected before use
# db: PostgreSQLDB
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):
print("no preloading of graph with AGE in production") print("no preloading of graph with AGE in production")
def __init__(self, namespace, global_config, embedding_func): def __post_init__(self):
super().__init__( self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self.graph_name = os.environ["AGE_GRAPH_NAME"]
self._node_embed_algorithms = { self._node_embed_algorithms = {
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
@@ -1137,7 +1131,7 @@ TABLES = {
"ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS ( "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
workspace varchar(255) NOT NULL, workspace varchar(255) NOT NULL,
id varchar(255) NOT NULL, id varchar(255) NOT NULL,
content TEXT, content TEXT NULL,
content_summary varchar(255) NULL, content_summary varchar(255) NULL,
content_length int4 NULL, content_length int4 NULL,
chunks_count int4 NULL, chunks_count int4 NULL,

View File

@@ -1,138 +0,0 @@
import asyncio
import sys
import os
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import asyncpg
import psycopg
from psycopg_pool import AsyncConnectionPool
from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from ..namespace import NameSpace
DB = "rag"
USER = "rag"
PASSWORD = "rag"
HOST = "localhost"
PORT = "15432"
os.environ["AGE_GRAPH_NAME"] = "dickens"
if sys.platform.startswith("win"):
import asyncio.windows_events
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
async def get_pool():
return await asyncpg.create_pool(
f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
min_size=10,
max_size=10,
max_queries=5000,
max_inactive_connection_lifetime=300.0,
)
async def main1():
connection_string = (
f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
)
pool = AsyncConnectionPool(connection_string, open=False)
await pool.open()
try:
conn = await pool.getconn(timeout=10)
async with conn.cursor() as curs:
try:
await curs.execute('SET search_path = ag_catalog, "$user", public')
await curs.execute("SELECT create_graph('dickens-2')")
await conn.commit()
print("create_graph success")
except (
psycopg.errors.InvalidSchemaName,
psycopg.errors.UniqueViolation,
):
print("create_graph already exists")
await conn.rollback()
finally:
pass
db = PostgreSQLDB(
config={
"host": "localhost",
"port": 15432,
"user": "rag",
"password": "rag",
"database": "r1",
}
)
async def query_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={},
embedding_func=None,
)
graph.db = db
res = await graph.get_node('"A CHRISTMAS CAROL"')
print("Node is: ", res)
res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
print("Edge is: ", res)
res = await graph.get_node_edges('"SCROOGE"')
print("Node Edges are: ", res)
async def create_edge_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={},
embedding_func=None,
)
graph.db = db
await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
await graph.upsert_edge(
'"THE CRATCHITS"',
'"THE GIRLS"',
edge_data={
"weight": 7.0,
"description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
"keywords": '"family, collective effort"',
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
},
)
res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
print("Edge is: ", res)
async def main():
pool = await get_pool()
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
# cypher = "MATCH (n:how_are_you_doing) RETURN n"
async with pool.acquire() as conn:
try:
await conn.execute(
"""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
)
except asyncpg.exceptions.InvalidSchemaNameError:
print("create_graph already exists")
# stmt = await conn.prepare(sql)
row = await conn.fetch(sql)
print("row is: ", row)
row = await conn.fetchrow("select '100'::int + 200 as result")
print(row) # <Record result=300>
if __name__ == "__main__":
asyncio.run(query_with_age())

View File

@@ -5,11 +5,10 @@ from dataclasses import dataclass
import numpy as np import numpy as np
import hashlib import hashlib
import uuid import uuid
from ..utils import logger from ..utils import logger
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
import pipmaster as pm import pipmaster as pm
import configparser
if not pm.is_installed("qdrant_client"): if not pm.is_installed("qdrant_client"):
pm.install("qdrant_client") pm.install("qdrant_client")
@@ -17,6 +16,10 @@ if not pm.is_installed("qdrant_client"):
from qdrant_client import QdrantClient, models from qdrant_client import QdrantClient, models
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
def compute_mdhash_id_for_qdrant( def compute_mdhash_id_for_qdrant(
content: str, prefix: str = "", style: str = "simple" content: str, prefix: str = "", style: str = "simple"
) -> str: ) -> str:
@@ -47,6 +50,8 @@ def compute_mdhash_id_for_qdrant(
@dataclass @dataclass
class QdrantVectorDBStorage(BaseVectorStorage): class QdrantVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod @staticmethod
def create_collection_if_not_exist( def create_collection_if_not_exist(
client: QdrantClient, collection_name: str, **kwargs client: QdrantClient, collection_name: str, **kwargs
@@ -56,9 +61,21 @@ class QdrantVectorDBStorage(BaseVectorStorage):
client.create_collection(collection_name, **kwargs) client.create_collection(collection_name, **kwargs)
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._client = QdrantClient( self._client = QdrantClient(
url=os.environ.get("QDRANT_URL"), url=os.environ.get(
api_key=os.environ.get("QDRANT_API_KEY", None), "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
),
api_key=os.environ.get(
"QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)
),
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
QdrantVectorDBStorage.create_collection_if_not_exist( QdrantVectorDBStorage.create_collection_if_not_exist(
@@ -122,4 +139,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
limit=top_k, limit=top_k,
with_payload=True, with_payload=True,
) )
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] logger.debug(f"query result: {results}")
# 添加余弦相似度过滤
filtered_results = [
dp for dp in results if dp.score >= self.cosine_better_than_threshold
]
return [
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
]

View File

@@ -3,6 +3,7 @@ from typing import Any, Union
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import pipmaster as pm import pipmaster as pm
import configparser
if not pm.is_installed("redis"): if not pm.is_installed("redis"):
pm.install("redis") pm.install("redis")
@@ -14,10 +15,16 @@ from lightrag.base import BaseKVStorage
import json import json
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass @dataclass
class RedisKVStorage(BaseKVStorage): class RedisKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379") redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
self._redis = Redis.from_url(redis_url, decode_responses=True) self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}") logger.info(f"Use Redis as KV {self.namespace}")

View File

@@ -101,7 +101,9 @@ class TiDB:
@dataclass @dataclass
class TiDBKVStorage(BaseKVStorage): class TiDBKVStorage(BaseKVStorage):
# should pass db object to self.db # db instance must be injected before use
# db: TiDB
def __post_init__(self): def __post_init__(self):
self._data = {} self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -208,18 +210,22 @@ class TiDBKVStorage(BaseKVStorage):
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) # db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
async def query(self, query: str, top_k: int) -> list[dict]: async def query(self, query: str, top_k: int) -> list[dict]:
"""Search from tidb vector""" """Search from tidb vector"""
@@ -329,6 +335,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: TiDB
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
import configparser
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
@@ -36,6 +37,111 @@ from .utils import (
) )
from .types import KnowledgeGraph from .types import KnowledgeGraph
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorge",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"],
"required_methods": ["get_pending_docs"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorge": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
# Storage implementation module mapping
STORAGES = { STORAGES = {
"NetworkXStorage": ".kg.networkx_impl", "NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl", "JsonKVStorage": ".kg.json_kv_impl",
@@ -140,6 +246,9 @@ class LightRAG:
graph_storage: str = field(default="NetworkXStorage") graph_storage: str = field(default="NetworkXStorage")
"""Storage backend for knowledge graphs.""" """Storage backend for knowledge graphs."""
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Logging # Logging
current_log_level = logger.level current_log_level = logger.level
log_level: int = field(default=current_log_level) log_level: int = field(default=current_log_level)
@@ -236,9 +345,6 @@ class LightRAG:
convert_response_to_json convert_response_to_json
) )
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Custom Chunking Function # Custom Chunking Function
chunking_func: Callable[ chunking_func: Callable[
[ [
@@ -252,6 +358,46 @@ class LightRAG:
list[dict[str, Any]], list[dict[str, Any]],
] = chunking_by_token_size ] = chunking_by_token_size
def verify_storage_implementation(
self, storage_type: str, storage_name: str
) -> None:
"""Verify if storage implementation is compatible with specified storage type
Args:
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
storage_name: Storage implementation name
Raises:
ValueError: If storage implementation is incompatible or missing required methods
"""
if storage_type not in STORAGE_IMPLEMENTATIONS:
raise ValueError(f"Unknown storage type: {storage_type}")
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
if storage_name not in storage_info["implementations"]:
raise ValueError(
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
)
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
def __post_init__(self): def __post_init__(self):
os.makedirs(self.log_dir, exist_ok=True) os.makedirs(self.log_dir, exist_ok=True)
log_file = os.path.join(self.log_dir, "lightrag.log") log_file = os.path.join(self.log_dir, "lightrag.log")
@@ -263,6 +409,29 @@ class LightRAG:
logger.info(f"Creating working directory {self.working_dir}") logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir) os.makedirs(self.working_dir)
# Verify storage implementation compatibility and environment variables
storage_configs = [
("KV_STORAGE", self.kv_storage),
("VECTOR_STORAGE", self.vector_storage),
("GRAPH_STORAGE", self.graph_storage),
("DOC_STATUS_STORAGE", self.doc_status_storage),
]
for storage_type, storage_name in storage_configs:
# Verify storage implementation compatibility
self.verify_storage_implementation(storage_type, storage_name)
# Check environment variables
self.check_storage_env_vars(storage_name)
# Ensure vector_db_storage_cls_kwargs has required fields
default_vector_db_kwargs = {
"cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
}
self.vector_db_storage_cls_kwargs = {
**default_vector_db_kwargs,
**self.vector_db_storage_cls_kwargs,
}
# show config # show config
global_config = asdict(self) global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -296,10 +465,8 @@ class LightRAG:
self.graph_storage_cls, global_config=global_config self.graph_storage_cls, global_config=global_config
) )
self.json_doc_status_storage = self.key_string_value_json_storage_cls( # Initialize document status storage
namespace=self.namespace_prefix + "json_doc_status_storage", self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
embedding_func=None,
)
self.llm_response_cache = self.key_string_value_json_storage_cls( self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=make_namespace( namespace=make_namespace(
@@ -308,9 +475,6 @@ class LightRAG:
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
####
# add embedding func by walter
####
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
namespace=make_namespace( namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
@@ -329,9 +493,6 @@ class LightRAG:
), ),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
####
# add embedding func by walter over
####
self.entities_vdb = self.vector_db_storage_cls( self.entities_vdb = self.vector_db_storage_cls(
namespace=make_namespace( namespace=make_namespace(
@@ -354,6 +515,14 @@ class LightRAG:
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
# Initialize document status storage
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
# What's for, Is this nessisary ?
if self.llm_response_cache and hasattr( if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config" self.llm_response_cache, "global_config"
): ):
@@ -374,14 +543,6 @@ class LightRAG:
) )
) )
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
async def get_graph_labels(self): async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels() text = await self.chunk_entity_relation_graph.get_all_labels()
return text return text
@@ -399,7 +560,8 @@ class LightRAG:
return storage_class return storage_class
def set_storage_client(self, db_client): def set_storage_client(self, db_client):
# Now only tested on Oracle Database # Deprecated, seting correct value to *_storage of LightRAG insteaded
# Inject db to storage implementation (only tested on Oracle Database)
for storage in [ for storage in [
self.vector_db_storage_cls, self.vector_db_storage_cls,
self.graph_storage_cls, self.graph_storage_cls,

View File

@@ -1055,6 +1055,9 @@ async def _get_node_data(
query_param: QueryParam, query_param: QueryParam,
): ):
# get similar entities # get similar entities
logger.info(
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
)
results = await entities_vdb.query(query, top_k=query_param.top_k) results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results): if not len(results):
return "", "", "" return "", "", ""
@@ -1270,6 +1273,9 @@ async def _get_edge_data(
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
): ):
logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k) results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results): if not len(results):

View File

@@ -416,7 +416,13 @@ async def get_best_cached_response(
if best_similarity > similarity_threshold: if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided # If LLM check is enabled and all required parameters are provided
if use_llm_check and llm_func and original_prompt and best_prompt: if (
use_llm_check
and llm_func
and original_prompt
and best_prompt
and best_response is not None
):
compare_prompt = PROMPTS["similarity_check"].format( compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt original_prompt=original_prompt, cached_prompt=best_prompt
) )
@@ -430,7 +436,9 @@ async def get_best_cached_response(
best_similarity = llm_similarity best_similarity = llm_similarity
if best_similarity < similarity_threshold: if best_similarity < similarity_threshold:
log_data = { log_data = {
"event": "llm_check_cache_rejected", "event": "cache_rejected_by_llm",
"type": cache_type,
"mode": mode,
"original_question": original_prompt[:100] + "..." "original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100 if len(original_prompt) > 100
else original_prompt, else original_prompt,
@@ -440,7 +448,8 @@ async def get_best_cached_response(
"similarity_score": round(best_similarity, 4), "similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold, "threshold": similarity_threshold,
} }
logger.info(json.dumps(log_data, ensure_ascii=False)) logger.debug(json.dumps(log_data, ensure_ascii=False))
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
return None return None
except Exception as e: # Catch all possible exceptions except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}") logger.warning(f"LLM similarity check failed: {e}")
@@ -451,12 +460,13 @@ async def get_best_cached_response(
) )
log_data = { log_data = {
"event": "cache_hit", "event": "cache_hit",
"type": cache_type,
"mode": mode, "mode": mode,
"similarity": round(best_similarity, 4), "similarity": round(best_similarity, 4),
"cache_id": best_cache_id, "cache_id": best_cache_id,
"original_prompt": prompt_display, "original_prompt": prompt_display,
} }
logger.info(json.dumps(log_data, ensure_ascii=False)) logger.debug(json.dumps(log_data, ensure_ascii=False))
return best_response return best_response
return None return None
@@ -534,19 +544,24 @@ async def handle_cache(
cache_type=cache_type, cache_type=cache_type,
) )
if best_cached_response is not None: if best_cached_response is not None:
logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
return best_cached_response, None, None, None return best_cached_response, None, None, None
else: else:
# if caching keyword embedding is enabled, return the quantized embedding for saving it latter
logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
return None, quantized, min_val, max_val return None, quantized, min_val, max_val
# For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False # For default mode or is_embedding_cache_enabled is False, use regular cache
# Use regular cache # default mode is for extract_entities or naive query
if exists_func(hashing_kv, "get_by_mode_and_id"): if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else: else:
mode_cache = await hashing_kv.get_by_id(mode) or {} mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache: if args_hash in mode_cache:
logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
return mode_cache[args_hash]["return"], None, None, None return mode_cache[args_hash]["return"], None, None, None
logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
return None, None, None, None return None, None, None, None