Merge pull request #744 from danielaskdd/select-datastore-in-api-server

Add datastore selection feature for API Server
This commit is contained in:
zrguo
2025-02-13 20:04:16 +08:00
committed by GitHub
23 changed files with 897 additions and 444 deletions

View File

@@ -1,12 +1,30 @@
# Server Configuration
HOST=0.0.0.0
PORT=9621
### Server Configuration
#HOST=0.0.0.0
#PORT=9621
#NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
# Directory Configuration
WORKING_DIR=/app/data/rag_storage
INPUT_DIR=/app/data/inputs
### Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem
# RAG Configuration
### Security (empty for no api-key is needed)
# LIGHTRAG_API_KEY=your-secure-api-key-here
### Directory Configuration
# WORKING_DIR=./rag_storage
# INPUT_DIR=./inputs
### Logging level
LOG_LEVEL=INFO
### Optional Timeout
TIMEOUT=300
# Ollama Emulating Model Tag
# OLLAMA_EMULATING_MODEL_TAG=latest
### RAG Configuration
MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
@@ -14,56 +32,42 @@ MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3
#CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.4 # 0.2 while not running API server
#TOP_K=50 # 60 while not running API server
#COSINE_THRESHOLD=0.2
#TOP_K=60
# LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
# Ollama example
### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
### Ollama example
LLM_BINDING=ollama
LLM_BINDING_HOST=http://host.docker.internal:11434
LLM_MODEL=mistral-nemo:latest
# OpenAI alike example
### OpenAI alike example
# LLM_BINDING=openai
# LLM_MODEL=deepseek-chat
# LLM_BINDING_HOST=https://api.deepseek.com
# LLM_BINDING_API_KEY=your_api_key
# for OpenAI LLM (LLM_BINDING_API_KEY take priority)
### for OpenAI LLM (LLM_BINDING_API_KEY take priority)
# OPENAI_API_KEY=your_api_key
# Lollms example
### Lollms example
# LLM_BINDING=lollms
# LLM_BINDING_HOST=http://host.docker.internal:9600
# LLM_MODEL=mistral-nemo:latest
# Embedding Configuration (Use valid host. For local services, you can use host.docker.internal)
### Embedding Configuration (Use valid host. For local services, you can use host.docker.internal)
# Ollama example
EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://host.docker.internal:11434
EMBEDDING_MODEL=bge-m3:latest
# Lollms example
### Lollms example
# EMBEDDING_BINDING=lollms
# EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
# EMBEDDING_MODEL=bge-m3:latest
# Security (empty for no key)
LIGHTRAG_API_KEY=your-secure-api-key-here
# Logging
LOG_LEVEL=INFO
# Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem
# Optional Timeout
#TIMEOUT=30
# Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
### Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
# AZURE_OPENAI_API_VERSION=2024-08-01-preview
# AZURE_OPENAI_DEPLOYMENT=gpt-4o
# AZURE_OPENAI_API_KEY=myapikey
@@ -72,6 +76,57 @@ LOG_LEVEL=INFO
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
# AZURE_EMBEDDING_API_VERSION=2023-05-15
### Data storage selection
# LIGHTRAG_KV_STORAGE=PGKVStorage
# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
# LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
# Ollama Emulating Model Tag
# OLLAMA_EMULATING_MODEL_TAG=latest
### Oracle Database Configuration
ORACLE_DSN=localhost:1521/XEPDB1
ORACLE_USER=your_username
ORACLE_PASSWORD='your_password'
ORACLE_CONFIG_DIR=/path/to/oracle/config
#ORACLE_WALLET_LOCATION=/path/to/wallet # optional
#ORACLE_WALLET_PASSWORD='your_password' # optional
#ORACLE_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
### TiDB Configuration
TIDB_HOST=localhost
TIDB_PORT=4000
TIDB_USER=your_username
TIDB_PASSWORD='your_password'
TIDB_DATABASE=your_database
#TIDB_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
### PostgreSQL Configuration
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database
#POSTGRES_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
### Independent AGM Configuration(not for AMG embedded in PostreSQL)
AGE_POSTGRES_DB=
AGE_POSTGRES_USER=
AGE_POSTGRES_PASSWORD=
AGE_POSTGRES_HOST=
# AGE_POSTGRES_PORT=8529
# AGE Graph Name(apply to PostgreSQL and independent AGM)
# AGE_GRAPH_NAME=lightrag # deprecated, use NAME_SPACE_PREFIX instead
### Neo4j Configuration
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD='your_password'
### MongoDB Configuration
MONGODB_URI=mongodb://root:root@localhost:27017/
MONGODB_DATABASE=LightRAG
MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
### Qdrant
QDRANT_URL=http://localhost:16333
QDRANT_API_KEY=your-api-key # 可选

View File

@@ -13,3 +13,28 @@ uri=redis://localhost:6379/1
[qdrant]
uri = http://localhost:16333
[oracle]
dsn = localhost:1521/XEPDB1
user = your_username
password = your_password
config_dir = /path/to/oracle/config
wallet_location = /path/to/wallet # 可选
wallet_password = your_wallet_password # 可选
workspace = default # 可选,默认为default
[tidb]
host = localhost
port = 4000
user = your_username
password = your_password
database = your_database
workspace = default # 可选,默认为default
[postgres]
host = localhost
port = 5432
user = your_username
password = your_password
database = your_database
workspace = default # 可选,默认为default

View File

@@ -103,66 +103,23 @@ After starting the lightrag-server, you can add an Ollama-type connection in the
LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.
For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD remains at its default value of 0.2 in LightRAG, many irrelevant entities and relations would be retrieved and sent to the LLM.
Default `TOP_K` is set to `60`. Default `COSINE_THRESHOLD` are set to `0.2`.
### Environment Variables
You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of available environment variables:
You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. A sample file `.env.example` is provided for your convenience.
```env
# Server Configuration
HOST=0.0.0.0
PORT=9621
### Config.ini
# Directory Configuration
WORKING_DIR=/app/data/rag_storage
INPUT_DIR=/app/data/inputs
# RAG Configuration
MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3
#CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.4
#TOP_K=50
# LLM Configuration
LLM_BINDING=ollama
LLM_BINDING_HOST=http://localhost:11434
LLM_MODEL=mistral-nemo:latest
# must be set if using OpenAI LLM (LLM_MODEL must be set or set by command line parms)
OPENAI_API_KEY=you_api_key
# Embedding Configuration
EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://localhost:11434
EMBEDDING_MODEL=bge-m3:latest
# Security
#LIGHTRAG_API_KEY=you-api-key-for-accessing-LightRAG
# Logging
LOG_LEVEL=INFO
# Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem
# Optional Timeout
#TIMEOUT=30
```
Datastorage configuration can be also set by config.ini. A sample file `config.ini.example` is provided for your convenience.
### Configuration Priority
The configuration values are loaded in the following order (highest priority first):
1. Command-line arguments
2. Environment variables
3. Default values
3. Config.ini
4. Defaul values
For example:
```bash
@@ -173,7 +130,69 @@ python lightrag.py --port 8080
PORT=7000 python lightrag.py
```
#### LightRag Server Options
> Best practices: you can set your database setting in Config.ini while testing, and you use .env for production.
### Storage Types Supported
LightRAG uses 4 types of storage for difference purposes:
* KV_STORAGEllm response cache, text chunks, document information
* VECTOR_STORAGEentities vectors, relation vectors, chunks vectors
* GRAPH_STORAGEentity relation graph
* DOC_STATUS_STORAGEdocuments indexing status
Each storage type have servals implementations:
* KV_STORAGE supported implement-name
```
JsonKVStorage JsonFile(default)
MongoKVStorage MogonDB
RedisKVStorage Redis
TiDBKVStorage TiDB
PGKVStorage Postgres
OracleKVStorage Oracle
```
* GRAPH_STORAGE supported implement-name
```
NetworkXStorage NetworkX(defualt)
Neo4JStorage Neo4J
MongoGraphStorage MongoDB
TiDBGraphStorage TiDB
AGEStorage AGE
GremlinStorage Gremlin
PGGraphStorage Postgres
OracleGraphStorage Postgres
```
* VECTOR_STORAGE supported implement-name
```
NanoVectorDBStorage NanoVector(default)
MilvusVectorDBStorge Milvus
ChromaVectorDBStorage Chroma
TiDBVectorDBStorage TiDB
PGVectorStorage Postgres
FaissVectorDBStorage Faiss
QdrantVectorDBStorage Qdrant
OracleVectorDBStorag Oracle
```
* DOC_STATUS_STORAGEsupported implement-name
```
JsonDocStatusStorage JsonFile(default)
PGDocStatusStorage Postgres
MongoDocStatusStorage MongoDB
```
### How Select Storage Implementation
You can select storage implementation by enviroment variables or command line arguments. You can not change storage implementation selection after you add documents to LightRAG. Data migration from one storage implementation to anthor is not supported yet. For further information please read the sample env file or config.ini file.
### LightRag API Server Comand Line Options
| Parameter | Default | Description |
|-----------|---------|-------------|
@@ -200,6 +219,10 @@ PORT=7000 python lightrag.py
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
| --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
| --kv-storage | JsonKVStorage | implement-name of KV_STORAGE |
| --graph-storage | NetworkXStorage | implement-name of GRAPH_STORAGE |
| --vector-storage | NanoVectorDBStorage | implement-name of VECTOR_STORAGE |
| --doc-status-storage | JsonDocStatusStorage | implement-name of DOC_STATUS_STORAGE |
### Example Usage
@@ -343,6 +366,14 @@ curl -X POST "http://localhost:9621/documents/scan" --max-time 1800
> Ajust max-time according to the estimated index time for all new files.
#### DELETE /documents
Clear all documents from the RAG system.
```bash
curl -X DELETE "http://localhost:9621/documents"
```
### Ollama Emulation Endpoints
#### GET /api/version
@@ -372,14 +403,6 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/jso
> For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)
#### DELETE /documents
Clear all documents from the RAG system.
```bash
curl -X DELETE "http://localhost:9621/documents"
```
### Utility Endpoints
#### GET /health

View File

@@ -1 +1 @@
__api_version__ = "1.0.4"
__api_version__ = "1.0.5"

View File

@@ -26,7 +26,6 @@ import shutil
import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import sys
import configparser
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
@@ -34,25 +33,47 @@ from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm
from dotenv import load_dotenv
import configparser
from lightrag.utils import logger
from .ollama_api import (
OllamaAPI,
)
from .ollama_api import ollama_server_infos
from ..kg.postgres_impl import (
PostgreSQLDB,
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
)
from ..kg.oracle_impl import (
OracleDB,
OracleKVStorage,
OracleVectorDBStorage,
OracleGraphStorage,
)
from ..kg.tidb_impl import (
TiDB,
TiDBKVStorage,
TiDBVectorDBStorage,
TiDBGraphStorage,
)
# Load environment variables
load_dotenv(override=True)
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
class RAGStorageConfig:
class DefaultRAGStorageConfig:
KV_STORAGE = "JsonKVStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
GRAPH_STORAGE = "NetworkXStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
GRAPH_STORAGE = "NetworkXStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
# Initialize rag storage config
rag_storage_config = RAGStorageConfig()
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
@@ -80,61 +101,6 @@ def estimate_tokens(text: str) -> int:
return int(tokens)
# read config.ini
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Redis config
redis_uri = config.get("redis", "uri", fallback=None)
if redis_uri:
os.environ["REDIS_URI"] = redis_uri
rag_storage_config.KV_STORAGE = "RedisKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
# Neo4j config
neo4j_uri = config.get("neo4j", "uri", fallback=None)
neo4j_username = config.get("neo4j", "username", fallback=None)
neo4j_password = config.get("neo4j", "password", fallback=None)
if neo4j_uri:
os.environ["NEO4J_URI"] = neo4j_uri
os.environ["NEO4J_USERNAME"] = neo4j_username
os.environ["NEO4J_PASSWORD"] = neo4j_password
rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
# Milvus config
milvus_uri = config.get("milvus", "uri", fallback=None)
milvus_user = config.get("milvus", "user", fallback=None)
milvus_password = config.get("milvus", "password", fallback=None)
milvus_db_name = config.get("milvus", "db_name", fallback=None)
if milvus_uri:
os.environ["MILVUS_URI"] = milvus_uri
os.environ["MILVUS_USER"] = milvus_user
os.environ["MILVUS_PASSWORD"] = milvus_password
os.environ["MILVUS_DB_NAME"] = milvus_db_name
rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorage"
# Qdrant config
qdrant_uri = config.get("qdrant", "uri", fallback=None)
qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
if qdrant_uri:
os.environ["QDRANT_URL"] = qdrant_uri
if qdrant_api_key:
os.environ["QDRANT_API_KEY"] = qdrant_api_key
rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
# MongoDB config
mongo_uri = config.get("mongodb", "uri", fallback=None)
mongo_database = config.get("mongodb", "database", fallback="LightRAG")
mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database
rag_storage_config.KV_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
if mongo_graph:
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -247,6 +213,16 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.top_k}")
# System Configuration
ASCIIColors.magenta("\n💾 Storage Configuration:")
ASCIIColors.white(" ├─ KV Storage: ", end="")
ASCIIColors.yellow(f"{args.kv_storage}")
ASCIIColors.white(" ├─ Vector Storage: ", end="")
ASCIIColors.yellow(f"{args.vector_storage}")
ASCIIColors.white(" ├─ Graph Storage: ", end="")
ASCIIColors.yellow(f"{args.graph_storage}")
ASCIIColors.white(" └─ Document Status Storage: ", end="")
ASCIIColors.yellow(f"{args.doc_status_storage}")
ASCIIColors.magenta("\n🛠️ System Configuration:")
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
@@ -344,6 +320,35 @@ def parse_args() -> argparse.Namespace:
description="LightRAG FastAPI Server with separate working and input directories"
)
parser.add_argument(
"--kv-storage",
default=get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
),
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
)
parser.add_argument(
"--doc-status-storage",
default=get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
),
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
)
parser.add_argument(
"--graph-storage",
default=get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
),
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
)
parser.add_argument(
"--vector-storage",
default=get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
),
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
)
# Bindings configuration
parser.add_argument(
"--llm-binding",
@@ -528,13 +533,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--top-k",
type=int,
default=get_env_value("TOP_K", 50, int),
help="Number of most similar results to return (default: from env or 50)",
default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
default=get_env_value("COSINE_THRESHOLD", 0.4, float),
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
help="Cosine similarity threshold (default: from env or 0.4)",
)
@@ -667,7 +672,14 @@ def get_api_key_dependency(api_key: Optional[str]):
return api_key_auth
# Global configuration
global_top_k = 60 # default value
def create_app(args):
global global_top_k
global_top_k = args.top_k # save top_k from args
# Verify that bindings are correctly setup
if args.llm_binding not in [
"lollms",
@@ -713,25 +725,104 @@ def create_app(args):
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
# Startup logic
if args.auto_scan_at_startup:
try:
new_files = doc_manager.scan_directory_for_new_files()
for file_path in new_files:
try:
await index_file(file_path)
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
# Initialize database connections
postgres_db = None
oracle_db = None
tidb_db = None
ASCIIColors.info(
f"Indexed {len(new_files)} documents from {args.input_dir}"
try:
# Check if PostgreSQL is needed
if any(
isinstance(
storage_instance,
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
)
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
yield
# Cleanup logic (if needed)
pass
for _, storage_instance in storage_instances
):
postgres_db = PostgreSQLDB(_get_postgres_config())
await postgres_db.initdb()
await postgres_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(
PGKVStorage,
PGVectorStorage,
PGGraphStorage,
PGDocStatusStorage,
),
):
storage_instance.db = postgres_db
logger.info(f"Injected postgres_db to {storage_name}")
# Check if Oracle is needed
if any(
isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
)
for _, storage_instance in storage_instances
):
oracle_db = OracleDB(_get_oracle_config())
await oracle_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
):
storage_instance.db = oracle_db
logger.info(f"Injected oracle_db to {storage_name}")
# Check if TiDB is needed
if any(
isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
)
for _, storage_instance in storage_instances
):
tidb_db = TiDB(_get_tidb_config())
await tidb_db.check_tables()
for storage_name, storage_instance in storage_instances:
if isinstance(
storage_instance,
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
):
storage_instance.db = tidb_db
logger.info(f"Injected tidb_db to {storage_name}")
# Auto scan documents if enabled
if args.auto_scan_at_startup:
try:
new_files = doc_manager.scan_directory_for_new_files()
for file_path in new_files:
try:
await index_file(file_path)
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
ASCIIColors.info(
f"Indexed {len(new_files)} documents from {args.input_dir}"
)
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
yield
finally:
# Cleanup database connections
if postgres_db and hasattr(postgres_db, "pool"):
await postgres_db.pool.close()
logger.info("Closed PostgreSQL connection pool")
if oracle_db and hasattr(oracle_db, "pool"):
await oracle_db.pool.close()
logger.info("Closed Oracle connection pool")
if tidb_db and hasattr(tidb_db, "pool"):
await tidb_db.pool.close()
logger.info("Closed TiDB connection pool")
# Initialize FastAPI
app = FastAPI(
@@ -754,6 +845,92 @@ def create_app(args):
allow_headers=["*"],
)
# Database configuration functions
def _get_postgres_config():
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
def _get_oracle_config():
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
def _get_tidb_config():
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
@@ -872,10 +1049,10 @@ def create_app(args):
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
else {},
embedding_func=embedding_func,
kv_storage=rag_storage_config.KV_STORAGE,
graph_storage=rag_storage_config.GRAPH_STORAGE,
vector_storage=rag_storage_config.VECTOR_STORAGE,
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
@@ -903,10 +1080,10 @@ def create_app(args):
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
embedding_func=embedding_func,
kv_storage=rag_storage_config.KV_STORAGE,
graph_storage=rag_storage_config.GRAPH_STORAGE,
vector_storage=rag_storage_config.VECTOR_STORAGE,
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
@@ -920,6 +1097,18 @@ def create_app(args):
namespace_prefix=args.namespace_prefix,
)
# Collect all storage instances
storage_instances = [
("full_docs", rag.full_docs),
("text_chunks", rag.text_chunks),
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
("entities_vdb", rag.entities_vdb),
("relationships_vdb", rag.relationships_vdb),
("chunks_vdb", rag.chunks_vdb),
("doc_status", rag.doc_status),
("llm_response_cache", rag.llm_response_cache),
]
async def index_file(file_path: Union[str, Path]) -> None:
"""Index all files inside the folder with support for multiple file formats
@@ -1100,7 +1289,7 @@ def create_app(args):
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
top_k=args.top_k,
top_k=global_top_k,
),
)
@@ -1142,7 +1331,7 @@ def create_app(args):
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
top_k=args.top_k,
top_k=global_top_k,
),
)
@@ -1432,7 +1621,7 @@ def create_app(args):
return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
# Add Ollama API routes
ollama_api = OllamaAPI(rag)
ollama_api = OllamaAPI(rag, top_k=args.top_k)
app.include_router(ollama_api.router, prefix="/api")
@app.get("/documents", dependencies=[Depends(optional_api_key)])
@@ -1460,10 +1649,10 @@ def create_app(args):
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
"kv_storage": rag_storage_config.KV_STORAGE,
"doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
"graph_storage": rag_storage_config.GRAPH_STORAGE,
"vector_storage": rag_storage_config.VECTOR_STORAGE,
"kv_storage": args.kv_storage,
"doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage,
"vector_storage": args.vector_storage,
},
}

View File

@@ -148,9 +148,10 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
class OllamaAPI:
def __init__(self, rag: LightRAG):
def __init__(self, rag: LightRAG, top_k: int = 60):
self.rag = rag
self.ollama_server_infos = ollama_server_infos
self.top_k = top_k
self.router = APIRouter()
self.setup_routes()
@@ -381,7 +382,7 @@ class OllamaAPI:
"stream": request.stream,
"only_need_context": False,
"conversation_history": conversation_history,
"top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
"top_k": self.top_k,
}
if (

View File

@@ -75,8 +75,8 @@ class AGEStorage(BaseGraphStorage):
.replace("'", "\\'")
)
HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'")
PORT = int(os.environ["AGE_POSTGRES_PORT"])
self.graph_name = os.environ["AGE_GRAPH_NAME"]
PORT = os.environ.get("AGE_POSTGRES_PORT", "8529")
self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"

View File

@@ -1,4 +1,3 @@
import os
import asyncio
from dataclasses import dataclass
from typing import Union
@@ -13,15 +12,17 @@ from lightrag.utils import logger
class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation."""
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
cosine_better_than_threshold: float = None
def __post_init__(self):
try:
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
user_collection_settings = config.get("collection_settings", {})
# Default HNSW index settings for ChromaDB

View File

@@ -23,14 +23,17 @@ class FaissVectorDBStorage(BaseVectorStorage):
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
"""
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
cosine_better_than_threshold: float = None
def __post_init__(self):
# Grab config values if available
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
# Where to save index file if you want persistent storage
self._faiss_index_file = os.path.join(

View File

@@ -47,7 +47,9 @@ class GremlinStorage(BaseGraphStorage):
# All vertices will have graph={GRAPH} property, so that we can
# have several logical graphs for one source
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
GRAPH = GremlinStorage._to_value_map(
os.environ.get("GREMLIN_GRAPH", "LightRAG")
)
self.graph_name = GRAPH

View File

@@ -5,16 +5,22 @@ from dataclasses import dataclass
import numpy as np
from lightrag.utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
import configparser
if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
from pymilvus import MilvusClient
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass
class MilvusVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod
def create_collection_if_not_exist(
client: MilvusClient, collection_name: str, **kwargs
@@ -26,15 +32,37 @@ class MilvusVectorDBStorage(BaseVectorStorage):
)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._client = MilvusClient(
uri=os.environ.get(
"MILVUS_URI",
os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
config.get(
"milvus",
"uri",
fallback=os.path.join(
self.global_config["working_dir"], "milvus_lite.db"
),
),
),
user=os.environ.get(
"MILVUS_USER", config.get("milvus", "user", fallback=None)
),
password=os.environ.get(
"MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)
),
token=os.environ.get(
"MILVUS_TOKEN", config.get("milvus", "token", fallback=None)
),
db_name=os.environ.get(
"MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)
),
user=os.environ.get("MILVUS_USER", ""),
password=os.environ.get("MILVUS_PASSWORD", ""),
token=os.environ.get("MILVUS_TOKEN", ""),
db_name=os.environ.get("MILVUS_DB_NAME", ""),
)
self._max_batch_size = self.global_config["embedding_batch_num"]
MilvusVectorDBStorage.create_collection_if_not_exist(
@@ -85,7 +113,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
data=embedding,
limit=top_k,
output_fields=list(self.meta_fields),
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
search_params={
"metric_type": "COSINE",
"params": {"radius": self.cosine_better_than_threshold},
},
)
print(results)
return [

View File

@@ -1,8 +1,8 @@
import os
from dataclasses import dataclass
import numpy as np
import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async
if not pm.is_installed("pymongo"):
@@ -12,7 +12,6 @@ if not pm.is_installed("motor"):
pm.install("motor")
from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
@@ -27,13 +26,27 @@ from ..namespace import NameSpace, is_namespace
from ..utils import logger
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
client = MongoClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
)
)
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as KV {self.namespace}")
@@ -173,10 +186,25 @@ class MongoGraphStorage(BaseGraphStorage):
embedding_func=embedding_func,
)
self.client = AsyncIOMotorClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
)
self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]
self.db = self.client[
os.environ.get(
"MONGO_DATABASE",
mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
)
]
self.collection = self.db[
os.environ.get(
"MONGO_KG_COLLECTION",
config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
)
]
#
# -------------------------------------------------------------------------

View File

@@ -73,16 +73,19 @@ from lightrag.base import (
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
cosine_better_than_threshold: float = None
def __post_init__(self):
# Initialize lock only for file operations
self._save_lock = asyncio.Lock()
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -139,9 +142,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])
embedding = embedding[0]
logger.info(
f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
)
results = self._client.query(
query=embedding,
top_k=top_k,

View File

@@ -5,6 +5,7 @@ import re
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import pipmaster as pm
import configparser
if not pm.is_installed("neo4j"):
pm.install("neo4j")
@@ -28,6 +29,10 @@ from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass
class Neo4JStorage(BaseGraphStorage):
@staticmethod
@@ -42,13 +47,22 @@ class Neo4JStorage(BaseGraphStorage):
)
self._driver = None
self._driver_lock = asyncio.Lock()
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
URI = os.environ["NEO4J_URI", config.get("neo4j", "uri", fallback=None)]
USERNAME = os.environ[
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
]
PASSWORD = os.environ[
"NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
]
MAX_CONNECTION_POOL_SIZE = os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE",
config.get("neo4j", "connection_pool_size", fallback=800),
)
DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
)
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)

View File

@@ -1,6 +1,5 @@
import array
import asyncio
import os
# import html
# import os
@@ -172,8 +171,8 @@ class OracleDB:
@dataclass
class OracleKVStorage(BaseKVStorage):
# should pass db object to self.db
db: OracleDB = None
# db instance must be injected before use
# db: OracleDB
meta_fields = None
def __post_init__(self):
@@ -318,16 +317,18 @@ class OracleKVStorage(BaseKVStorage):
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db
db: OracleDB = None
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
# db instance must be injected before use
# db: OracleDB
cosine_better_than_threshold: float = None
def __post_init__(self):
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据"""
@@ -361,7 +362,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
@dataclass
class OracleGraphStorage(BaseGraphStorage):
"""基于Oracle的图存储模块"""
# db instance must be injected before use
# db: OracleDB
def __post_init__(self):
"""从graphml文件加载图"""

View File

@@ -177,7 +177,8 @@ class PostgreSQLDB:
@dataclass
class PGKVStorage(BaseKVStorage):
db: PostgreSQLDB = None
# db instance must be injected before use
# db: PostgreSQLDB
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -296,16 +297,19 @@ class PGKVStorage(BaseKVStorage):
@dataclass
class PGVectorStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
db: PostgreSQLDB = None
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = None
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
def _upsert_chunks(self, item: dict):
try:
@@ -416,20 +420,14 @@ class PGVectorStorage(BaseVectorStorage):
@dataclass
class PGDocStatusStorage(DocStatusStorage):
"""PostgreSQL implementation of document status storage"""
db: PostgreSQLDB = None
def __post_init__(self):
pass
# db instance must be injected before use
# db: PostgreSQLDB
async def filter_keys(self, data: set[str]) -> set[str]:
"""Return keys that don't exist in storage"""
keys = ",".join([f"'{_id}'" for _id in data])
sql = (
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})"
)
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
result = await self.db.query(sql, multirows=True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if result is None:
return set(data)
@@ -585,19 +583,15 @@ class PGGraphQueryException(Exception):
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = None
# db instance must be injected before use
# db: PostgreSQLDB
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self.graph_name = os.environ["AGE_GRAPH_NAME"]
def __post_init__(self):
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
@@ -1137,7 +1131,7 @@ TABLES = {
"ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
workspace varchar(255) NOT NULL,
id varchar(255) NOT NULL,
content TEXT,
content TEXT NULL,
content_summary varchar(255) NULL,
content_length int4 NULL,
chunks_count int4 NULL,

View File

@@ -1,138 +0,0 @@
import asyncio
import sys
import os
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import asyncpg
import psycopg
from psycopg_pool import AsyncConnectionPool
from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from ..namespace import NameSpace
DB = "rag"
USER = "rag"
PASSWORD = "rag"
HOST = "localhost"
PORT = "15432"
os.environ["AGE_GRAPH_NAME"] = "dickens"
if sys.platform.startswith("win"):
import asyncio.windows_events
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
async def get_pool():
return await asyncpg.create_pool(
f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
min_size=10,
max_size=10,
max_queries=5000,
max_inactive_connection_lifetime=300.0,
)
async def main1():
connection_string = (
f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
)
pool = AsyncConnectionPool(connection_string, open=False)
await pool.open()
try:
conn = await pool.getconn(timeout=10)
async with conn.cursor() as curs:
try:
await curs.execute('SET search_path = ag_catalog, "$user", public')
await curs.execute("SELECT create_graph('dickens-2')")
await conn.commit()
print("create_graph success")
except (
psycopg.errors.InvalidSchemaName,
psycopg.errors.UniqueViolation,
):
print("create_graph already exists")
await conn.rollback()
finally:
pass
db = PostgreSQLDB(
config={
"host": "localhost",
"port": 15432,
"user": "rag",
"password": "rag",
"database": "r1",
}
)
async def query_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={},
embedding_func=None,
)
graph.db = db
res = await graph.get_node('"A CHRISTMAS CAROL"')
print("Node is: ", res)
res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
print("Edge is: ", res)
res = await graph.get_node_edges('"SCROOGE"')
print("Node Edges are: ", res)
async def create_edge_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
global_config={},
embedding_func=None,
)
graph.db = db
await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
await graph.upsert_edge(
'"THE CRATCHITS"',
'"THE GIRLS"',
edge_data={
"weight": 7.0,
"description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
"keywords": '"family, collective effort"',
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
},
)
res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
print("Edge is: ", res)
async def main():
pool = await get_pool()
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
# cypher = "MATCH (n:how_are_you_doing) RETURN n"
async with pool.acquire() as conn:
try:
await conn.execute(
"""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
)
except asyncpg.exceptions.InvalidSchemaNameError:
print("create_graph already exists")
# stmt = await conn.prepare(sql)
row = await conn.fetch(sql)
print("row is: ", row)
row = await conn.fetchrow("select '100'::int + 200 as result")
print(row) # <Record result=300>
if __name__ == "__main__":
asyncio.run(query_with_age())

View File

@@ -5,11 +5,10 @@ from dataclasses import dataclass
import numpy as np
import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
import configparser
if not pm.is_installed("qdrant_client"):
pm.install("qdrant_client")
@@ -17,6 +16,10 @@ if not pm.is_installed("qdrant_client"):
from qdrant_client import QdrantClient, models
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
def compute_mdhash_id_for_qdrant(
content: str, prefix: str = "", style: str = "simple"
) -> str:
@@ -47,6 +50,8 @@ def compute_mdhash_id_for_qdrant(
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod
def create_collection_if_not_exist(
client: QdrantClient, collection_name: str, **kwargs
@@ -56,9 +61,21 @@ class QdrantVectorDBStorage(BaseVectorStorage):
client.create_collection(collection_name, **kwargs)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._client = QdrantClient(
url=os.environ.get("QDRANT_URL"),
api_key=os.environ.get("QDRANT_API_KEY", None),
url=os.environ.get(
"QDRANT_URL", config.get("qdrant", "uri", fallback=None)
),
api_key=os.environ.get(
"QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)
),
)
self._max_batch_size = self.global_config["embedding_batch_num"]
QdrantVectorDBStorage.create_collection_if_not_exist(
@@ -122,4 +139,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
limit=top_k,
with_payload=True,
)
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
logger.debug(f"query result: {results}")
# 添加余弦相似度过滤
filtered_results = [
dp for dp in results if dp.score >= self.cosine_better_than_threshold
]
return [
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
]

View File

@@ -3,6 +3,7 @@ from typing import Any, Union
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
import configparser
if not pm.is_installed("redis"):
pm.install("redis")
@@ -14,10 +15,16 @@ from lightrag.base import BaseKVStorage
import json
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass
class RedisKVStorage(BaseKVStorage):
def __post_init__(self):
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
)
self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}")

View File

@@ -101,7 +101,9 @@ class TiDB:
@dataclass
class TiDBKVStorage(BaseKVStorage):
# should pass db object to self.db
# db instance must be injected before use
# db: TiDB
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -208,18 +210,22 @@ class TiDBKVStorage(BaseKVStorage):
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
# db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = None
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
async def query(self, query: str, top_k: int) -> list[dict]:
"""Search from tidb vector"""
@@ -329,6 +335,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: TiDB
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]

View File

@@ -1,5 +1,6 @@
import asyncio
import os
import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
@@ -36,6 +37,111 @@ from .utils import (
)
from .types import KnowledgeGraph
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Storage type and implementation compatibility validation table
STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorge",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"],
"required_methods": ["get_pending_docs"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorge": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
# Storage implementation module mapping
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
@@ -140,6 +246,9 @@ class LightRAG:
graph_storage: str = field(default="NetworkXStorage")
"""Storage backend for knowledge graphs."""
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Logging
current_log_level = logger.level
log_level: int = field(default=current_log_level)
@@ -236,9 +345,6 @@ class LightRAG:
convert_response_to_json
)
doc_status_storage: str = field(default="JsonDocStatusStorage")
"""Storage type for tracking document processing statuses."""
# Custom Chunking Function
chunking_func: Callable[
[
@@ -252,6 +358,46 @@ class LightRAG:
list[dict[str, Any]],
] = chunking_by_token_size
def verify_storage_implementation(
self, storage_type: str, storage_name: str
) -> None:
"""Verify if storage implementation is compatible with specified storage type
Args:
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
storage_name: Storage implementation name
Raises:
ValueError: If storage implementation is incompatible or missing required methods
"""
if storage_type not in STORAGE_IMPLEMENTATIONS:
raise ValueError(f"Unknown storage type: {storage_type}")
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
if storage_name not in storage_info["implementations"]:
raise ValueError(
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
)
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
def __post_init__(self):
os.makedirs(self.log_dir, exist_ok=True)
log_file = os.path.join(self.log_dir, "lightrag.log")
@@ -263,6 +409,29 @@ class LightRAG:
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# Verify storage implementation compatibility and environment variables
storage_configs = [
("KV_STORAGE", self.kv_storage),
("VECTOR_STORAGE", self.vector_storage),
("GRAPH_STORAGE", self.graph_storage),
("DOC_STATUS_STORAGE", self.doc_status_storage),
]
for storage_type, storage_name in storage_configs:
# Verify storage implementation compatibility
self.verify_storage_implementation(storage_type, storage_name)
# Check environment variables
self.check_storage_env_vars(storage_name)
# Ensure vector_db_storage_cls_kwargs has required fields
default_vector_db_kwargs = {
"cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
}
self.vector_db_storage_cls_kwargs = {
**default_vector_db_kwargs,
**self.vector_db_storage_cls_kwargs,
}
# show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -296,10 +465,8 @@ class LightRAG:
self.graph_storage_cls, global_config=global_config
)
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace=self.namespace_prefix + "json_doc_status_storage",
embedding_func=None,
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace=make_namespace(
@@ -308,9 +475,6 @@ class LightRAG:
embedding_func=self.embedding_func,
)
####
# add embedding func by walter
####
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
@@ -329,9 +493,6 @@ class LightRAG:
),
embedding_func=self.embedding_func,
)
####
# add embedding func by walter over
####
self.entities_vdb = self.vector_db_storage_cls(
namespace=make_namespace(
@@ -354,6 +515,14 @@ class LightRAG:
embedding_func=self.embedding_func,
)
# Initialize document status storage
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
# What's for, Is this nessisary ?
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
@@ -374,14 +543,6 @@ class LightRAG:
)
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
)
async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels()
return text
@@ -399,7 +560,8 @@ class LightRAG:
return storage_class
def set_storage_client(self, db_client):
# Now only tested on Oracle Database
# Deprecated, seting correct value to *_storage of LightRAG insteaded
# Inject db to storage implementation (only tested on Oracle Database)
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,

View File

@@ -1055,6 +1055,9 @@ async def _get_node_data(
query_param: QueryParam,
):
# get similar entities
logger.info(
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
)
results = await entities_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return "", "", ""
@@ -1270,6 +1273,9 @@ async def _get_edge_data(
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
):
logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results):

View File

@@ -416,7 +416,13 @@ async def get_best_cached_response(
if best_similarity > similarity_threshold:
# If LLM check is enabled and all required parameters are provided
if use_llm_check and llm_func and original_prompt and best_prompt:
if (
use_llm_check
and llm_func
and original_prompt
and best_prompt
and best_response is not None
):
compare_prompt = PROMPTS["similarity_check"].format(
original_prompt=original_prompt, cached_prompt=best_prompt
)
@@ -430,7 +436,9 @@ async def get_best_cached_response(
best_similarity = llm_similarity
if best_similarity < similarity_threshold:
log_data = {
"event": "llm_check_cache_rejected",
"event": "cache_rejected_by_llm",
"type": cache_type,
"mode": mode,
"original_question": original_prompt[:100] + "..."
if len(original_prompt) > 100
else original_prompt,
@@ -440,7 +448,8 @@ async def get_best_cached_response(
"similarity_score": round(best_similarity, 4),
"threshold": similarity_threshold,
}
logger.info(json.dumps(log_data, ensure_ascii=False))
logger.debug(json.dumps(log_data, ensure_ascii=False))
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
return None
except Exception as e: # Catch all possible exceptions
logger.warning(f"LLM similarity check failed: {e}")
@@ -451,12 +460,13 @@ async def get_best_cached_response(
)
log_data = {
"event": "cache_hit",
"type": cache_type,
"mode": mode,
"similarity": round(best_similarity, 4),
"cache_id": best_cache_id,
"original_prompt": prompt_display,
}
logger.info(json.dumps(log_data, ensure_ascii=False))
logger.debug(json.dumps(log_data, ensure_ascii=False))
return best_response
return None
@@ -534,19 +544,24 @@ async def handle_cache(
cache_type=cache_type,
)
if best_cached_response is not None:
logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
return best_cached_response, None, None, None
else:
# if caching keyword embedding is enabled, return the quantized embedding for saving it latter
logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
return None, quantized, min_val, max_val
# For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
# Use regular cache
# For default mode or is_embedding_cache_enabled is False, use regular cache
# default mode is for extract_entities or naive query
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
return mode_cache[args_hash]["return"], None, None, None
logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
return None, None, None, None