Merge pull request #744 from danielaskdd/select-datastore-in-api-server
Add datastore selection feature for API Server
This commit is contained in:
121
.env.example
121
.env.example
@@ -1,12 +1,30 @@
|
||||
# Server Configuration
|
||||
HOST=0.0.0.0
|
||||
PORT=9621
|
||||
### Server Configuration
|
||||
#HOST=0.0.0.0
|
||||
#PORT=9621
|
||||
#NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
|
||||
|
||||
# Directory Configuration
|
||||
WORKING_DIR=/app/data/rag_storage
|
||||
INPUT_DIR=/app/data/inputs
|
||||
### Optional SSL Configuration
|
||||
#SSL=true
|
||||
#SSL_CERTFILE=/path/to/cert.pem
|
||||
#SSL_KEYFILE=/path/to/key.pem
|
||||
|
||||
# RAG Configuration
|
||||
### Security (empty for no api-key is needed)
|
||||
# LIGHTRAG_API_KEY=your-secure-api-key-here
|
||||
|
||||
### Directory Configuration
|
||||
# WORKING_DIR=./rag_storage
|
||||
# INPUT_DIR=./inputs
|
||||
|
||||
### Logging level
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
### Optional Timeout
|
||||
TIMEOUT=300
|
||||
|
||||
# Ollama Emulating Model Tag
|
||||
# OLLAMA_EMULATING_MODEL_TAG=latest
|
||||
|
||||
### RAG Configuration
|
||||
MAX_ASYNC=4
|
||||
MAX_TOKENS=32768
|
||||
EMBEDDING_DIM=1024
|
||||
@@ -14,56 +32,42 @@ MAX_EMBED_TOKENS=8192
|
||||
#HISTORY_TURNS=3
|
||||
#CHUNK_SIZE=1200
|
||||
#CHUNK_OVERLAP_SIZE=100
|
||||
#COSINE_THRESHOLD=0.4 # 0.2 while not running API server
|
||||
#TOP_K=50 # 60 while not running API server
|
||||
#COSINE_THRESHOLD=0.2
|
||||
#TOP_K=60
|
||||
|
||||
# LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
|
||||
# Ollama example
|
||||
### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
|
||||
### Ollama example
|
||||
LLM_BINDING=ollama
|
||||
LLM_BINDING_HOST=http://host.docker.internal:11434
|
||||
LLM_MODEL=mistral-nemo:latest
|
||||
|
||||
# OpenAI alike example
|
||||
### OpenAI alike example
|
||||
# LLM_BINDING=openai
|
||||
# LLM_MODEL=deepseek-chat
|
||||
# LLM_BINDING_HOST=https://api.deepseek.com
|
||||
# LLM_BINDING_API_KEY=your_api_key
|
||||
|
||||
# for OpenAI LLM (LLM_BINDING_API_KEY take priority)
|
||||
### for OpenAI LLM (LLM_BINDING_API_KEY take priority)
|
||||
# OPENAI_API_KEY=your_api_key
|
||||
|
||||
# Lollms example
|
||||
### Lollms example
|
||||
# LLM_BINDING=lollms
|
||||
# LLM_BINDING_HOST=http://host.docker.internal:9600
|
||||
# LLM_MODEL=mistral-nemo:latest
|
||||
|
||||
|
||||
# Embedding Configuration (Use valid host. For local services, you can use host.docker.internal)
|
||||
### Embedding Configuration (Use valid host. For local services, you can use host.docker.internal)
|
||||
# Ollama example
|
||||
EMBEDDING_BINDING=ollama
|
||||
EMBEDDING_BINDING_HOST=http://host.docker.internal:11434
|
||||
EMBEDDING_MODEL=bge-m3:latest
|
||||
|
||||
# Lollms example
|
||||
### Lollms example
|
||||
# EMBEDDING_BINDING=lollms
|
||||
# EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
|
||||
# EMBEDDING_MODEL=bge-m3:latest
|
||||
|
||||
# Security (empty for no key)
|
||||
LIGHTRAG_API_KEY=your-secure-api-key-here
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Optional SSL Configuration
|
||||
#SSL=true
|
||||
#SSL_CERTFILE=/path/to/cert.pem
|
||||
#SSL_KEYFILE=/path/to/key.pem
|
||||
|
||||
# Optional Timeout
|
||||
#TIMEOUT=30
|
||||
|
||||
# Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
|
||||
### Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
|
||||
# AZURE_OPENAI_API_VERSION=2024-08-01-preview
|
||||
# AZURE_OPENAI_DEPLOYMENT=gpt-4o
|
||||
# AZURE_OPENAI_API_KEY=myapikey
|
||||
@@ -72,6 +76,57 @@ LOG_LEVEL=INFO
|
||||
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
|
||||
# AZURE_EMBEDDING_API_VERSION=2023-05-15
|
||||
|
||||
### Data storage selection
|
||||
# LIGHTRAG_KV_STORAGE=PGKVStorage
|
||||
# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
|
||||
# LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
|
||||
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
|
||||
|
||||
# Ollama Emulating Model Tag
|
||||
# OLLAMA_EMULATING_MODEL_TAG=latest
|
||||
### Oracle Database Configuration
|
||||
ORACLE_DSN=localhost:1521/XEPDB1
|
||||
ORACLE_USER=your_username
|
||||
ORACLE_PASSWORD='your_password'
|
||||
ORACLE_CONFIG_DIR=/path/to/oracle/config
|
||||
#ORACLE_WALLET_LOCATION=/path/to/wallet # optional
|
||||
#ORACLE_WALLET_PASSWORD='your_password' # optional
|
||||
#ORACLE_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
|
||||
|
||||
### TiDB Configuration
|
||||
TIDB_HOST=localhost
|
||||
TIDB_PORT=4000
|
||||
TIDB_USER=your_username
|
||||
TIDB_PASSWORD='your_password'
|
||||
TIDB_DATABASE=your_database
|
||||
#TIDB_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
|
||||
|
||||
### PostgreSQL Configuration
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_USER=your_username
|
||||
POSTGRES_PASSWORD='your_password'
|
||||
POSTGRES_DATABASE=your_database
|
||||
#POSTGRES_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future)
|
||||
|
||||
### Independent AGM Configuration(not for AMG embedded in PostreSQL)
|
||||
AGE_POSTGRES_DB=
|
||||
AGE_POSTGRES_USER=
|
||||
AGE_POSTGRES_PASSWORD=
|
||||
AGE_POSTGRES_HOST=
|
||||
# AGE_POSTGRES_PORT=8529
|
||||
|
||||
# AGE Graph Name(apply to PostgreSQL and independent AGM)
|
||||
# AGE_GRAPH_NAME=lightrag # deprecated, use NAME_SPACE_PREFIX instead
|
||||
|
||||
### Neo4j Configuration
|
||||
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
||||
NEO4J_USERNAME=neo4j
|
||||
NEO4J_PASSWORD='your_password'
|
||||
|
||||
### MongoDB Configuration
|
||||
MONGODB_URI=mongodb://root:root@localhost:27017/
|
||||
MONGODB_DATABASE=LightRAG
|
||||
MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
|
||||
|
||||
### Qdrant
|
||||
QDRANT_URL=http://localhost:16333
|
||||
QDRANT_API_KEY=your-api-key # 可选
|
||||
|
@@ -13,3 +13,28 @@ uri=redis://localhost:6379/1
|
||||
|
||||
[qdrant]
|
||||
uri = http://localhost:16333
|
||||
|
||||
[oracle]
|
||||
dsn = localhost:1521/XEPDB1
|
||||
user = your_username
|
||||
password = your_password
|
||||
config_dir = /path/to/oracle/config
|
||||
wallet_location = /path/to/wallet # 可选
|
||||
wallet_password = your_wallet_password # 可选
|
||||
workspace = default # 可选,默认为default
|
||||
|
||||
[tidb]
|
||||
host = localhost
|
||||
port = 4000
|
||||
user = your_username
|
||||
password = your_password
|
||||
database = your_database
|
||||
workspace = default # 可选,默认为default
|
||||
|
||||
[postgres]
|
||||
host = localhost
|
||||
port = 5432
|
||||
user = your_username
|
||||
password = your_password
|
||||
database = your_database
|
||||
workspace = default # 可选,默认为default
|
||||
|
@@ -103,66 +103,23 @@ After starting the lightrag-server, you can add an Ollama-type connection in the
|
||||
|
||||
LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.
|
||||
|
||||
For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD remains at its default value of 0.2 in LightRAG, many irrelevant entities and relations would be retrieved and sent to the LLM.
|
||||
Default `TOP_K` is set to `60`. Default `COSINE_THRESHOLD` are set to `0.2`.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of available environment variables:
|
||||
You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. A sample file `.env.example` is provided for your convenience.
|
||||
|
||||
```env
|
||||
# Server Configuration
|
||||
HOST=0.0.0.0
|
||||
PORT=9621
|
||||
### Config.ini
|
||||
|
||||
# Directory Configuration
|
||||
WORKING_DIR=/app/data/rag_storage
|
||||
INPUT_DIR=/app/data/inputs
|
||||
|
||||
# RAG Configuration
|
||||
MAX_ASYNC=4
|
||||
MAX_TOKENS=32768
|
||||
EMBEDDING_DIM=1024
|
||||
MAX_EMBED_TOKENS=8192
|
||||
#HISTORY_TURNS=3
|
||||
#CHUNK_SIZE=1200
|
||||
#CHUNK_OVERLAP_SIZE=100
|
||||
#COSINE_THRESHOLD=0.4
|
||||
#TOP_K=50
|
||||
|
||||
# LLM Configuration
|
||||
LLM_BINDING=ollama
|
||||
LLM_BINDING_HOST=http://localhost:11434
|
||||
LLM_MODEL=mistral-nemo:latest
|
||||
|
||||
# must be set if using OpenAI LLM (LLM_MODEL must be set or set by command line parms)
|
||||
OPENAI_API_KEY=you_api_key
|
||||
|
||||
# Embedding Configuration
|
||||
EMBEDDING_BINDING=ollama
|
||||
EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||
EMBEDDING_MODEL=bge-m3:latest
|
||||
|
||||
# Security
|
||||
#LIGHTRAG_API_KEY=you-api-key-for-accessing-LightRAG
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Optional SSL Configuration
|
||||
#SSL=true
|
||||
#SSL_CERTFILE=/path/to/cert.pem
|
||||
#SSL_KEYFILE=/path/to/key.pem
|
||||
|
||||
# Optional Timeout
|
||||
#TIMEOUT=30
|
||||
```
|
||||
Datastorage configuration can be also set by config.ini. A sample file `config.ini.example` is provided for your convenience.
|
||||
|
||||
### Configuration Priority
|
||||
|
||||
The configuration values are loaded in the following order (highest priority first):
|
||||
1. Command-line arguments
|
||||
2. Environment variables
|
||||
3. Default values
|
||||
3. Config.ini
|
||||
4. Defaul values
|
||||
|
||||
For example:
|
||||
```bash
|
||||
@@ -173,7 +130,69 @@ python lightrag.py --port 8080
|
||||
PORT=7000 python lightrag.py
|
||||
```
|
||||
|
||||
#### LightRag Server Options
|
||||
> Best practices: you can set your database setting in Config.ini while testing, and you use .env for production.
|
||||
|
||||
### Storage Types Supported
|
||||
|
||||
LightRAG uses 4 types of storage for difference purposes:
|
||||
|
||||
* KV_STORAGE:llm response cache, text chunks, document information
|
||||
* VECTOR_STORAGE:entities vectors, relation vectors, chunks vectors
|
||||
* GRAPH_STORAGE:entity relation graph
|
||||
* DOC_STATUS_STORAGE:documents indexing status
|
||||
|
||||
Each storage type have servals implementations:
|
||||
|
||||
* KV_STORAGE supported implement-name
|
||||
|
||||
```
|
||||
JsonKVStorage JsonFile(default)
|
||||
MongoKVStorage MogonDB
|
||||
RedisKVStorage Redis
|
||||
TiDBKVStorage TiDB
|
||||
PGKVStorage Postgres
|
||||
OracleKVStorage Oracle
|
||||
```
|
||||
|
||||
* GRAPH_STORAGE supported implement-name
|
||||
|
||||
```
|
||||
NetworkXStorage NetworkX(defualt)
|
||||
Neo4JStorage Neo4J
|
||||
MongoGraphStorage MongoDB
|
||||
TiDBGraphStorage TiDB
|
||||
AGEStorage AGE
|
||||
GremlinStorage Gremlin
|
||||
PGGraphStorage Postgres
|
||||
OracleGraphStorage Postgres
|
||||
```
|
||||
|
||||
* VECTOR_STORAGE supported implement-name
|
||||
|
||||
```
|
||||
NanoVectorDBStorage NanoVector(default)
|
||||
MilvusVectorDBStorge Milvus
|
||||
ChromaVectorDBStorage Chroma
|
||||
TiDBVectorDBStorage TiDB
|
||||
PGVectorStorage Postgres
|
||||
FaissVectorDBStorage Faiss
|
||||
QdrantVectorDBStorage Qdrant
|
||||
OracleVectorDBStorag Oracle
|
||||
```
|
||||
|
||||
* DOC_STATUS_STORAGE:supported implement-name
|
||||
|
||||
```
|
||||
JsonDocStatusStorage JsonFile(default)
|
||||
PGDocStatusStorage Postgres
|
||||
MongoDocStatusStorage MongoDB
|
||||
```
|
||||
|
||||
### How Select Storage Implementation
|
||||
|
||||
You can select storage implementation by enviroment variables or command line arguments. You can not change storage implementation selection after you add documents to LightRAG. Data migration from one storage implementation to anthor is not supported yet. For further information please read the sample env file or config.ini file.
|
||||
|
||||
### LightRag API Server Comand Line Options
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
@@ -200,6 +219,10 @@ PORT=7000 python lightrag.py
|
||||
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
|
||||
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
|
||||
| --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
|
||||
| --kv-storage | JsonKVStorage | implement-name of KV_STORAGE |
|
||||
| --graph-storage | NetworkXStorage | implement-name of GRAPH_STORAGE |
|
||||
| --vector-storage | NanoVectorDBStorage | implement-name of VECTOR_STORAGE |
|
||||
| --doc-status-storage | JsonDocStatusStorage | implement-name of DOC_STATUS_STORAGE |
|
||||
|
||||
### Example Usage
|
||||
|
||||
@@ -343,6 +366,14 @@ curl -X POST "http://localhost:9621/documents/scan" --max-time 1800
|
||||
|
||||
> Ajust max-time according to the estimated index time for all new files.
|
||||
|
||||
#### DELETE /documents
|
||||
|
||||
Clear all documents from the RAG system.
|
||||
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:9621/documents"
|
||||
```
|
||||
|
||||
### Ollama Emulation Endpoints
|
||||
|
||||
#### GET /api/version
|
||||
@@ -372,14 +403,6 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/jso
|
||||
|
||||
> For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)
|
||||
|
||||
#### DELETE /documents
|
||||
|
||||
Clear all documents from the RAG system.
|
||||
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:9621/documents"
|
||||
```
|
||||
|
||||
### Utility Endpoints
|
||||
|
||||
#### GET /health
|
||||
|
@@ -1 +1 @@
|
||||
__api_version__ = "1.0.4"
|
||||
__api_version__ = "1.0.5"
|
||||
|
@@ -26,7 +26,6 @@ import shutil
|
||||
import aiofiles
|
||||
from ascii_colors import trace_exception, ASCIIColors
|
||||
import sys
|
||||
import configparser
|
||||
from fastapi import Depends, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -34,25 +33,47 @@ from contextlib import asynccontextmanager
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
import pipmaster as pm
|
||||
from dotenv import load_dotenv
|
||||
import configparser
|
||||
from lightrag.utils import logger
|
||||
from .ollama_api import (
|
||||
OllamaAPI,
|
||||
)
|
||||
from .ollama_api import ollama_server_infos
|
||||
from ..kg.postgres_impl import (
|
||||
PostgreSQLDB,
|
||||
PGKVStorage,
|
||||
PGVectorStorage,
|
||||
PGGraphStorage,
|
||||
PGDocStatusStorage,
|
||||
)
|
||||
from ..kg.oracle_impl import (
|
||||
OracleDB,
|
||||
OracleKVStorage,
|
||||
OracleVectorDBStorage,
|
||||
OracleGraphStorage,
|
||||
)
|
||||
from ..kg.tidb_impl import (
|
||||
TiDB,
|
||||
TiDBKVStorage,
|
||||
TiDBVectorDBStorage,
|
||||
TiDBGraphStorage,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Initialize config parser
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini")
|
||||
|
||||
class RAGStorageConfig:
|
||||
|
||||
class DefaultRAGStorageConfig:
|
||||
KV_STORAGE = "JsonKVStorage"
|
||||
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
||||
GRAPH_STORAGE = "NetworkXStorage"
|
||||
VECTOR_STORAGE = "NanoVectorDBStorage"
|
||||
GRAPH_STORAGE = "NetworkXStorage"
|
||||
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
||||
|
||||
|
||||
# Initialize rag storage config
|
||||
rag_storage_config = RAGStorageConfig()
|
||||
|
||||
# Global progress tracker
|
||||
scan_progress: Dict = {
|
||||
"is_scanning": False,
|
||||
@@ -80,61 +101,6 @@ def estimate_tokens(text: str) -> int:
|
||||
|
||||
return int(tokens)
|
||||
|
||||
|
||||
# read config.ini
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
# Redis config
|
||||
redis_uri = config.get("redis", "uri", fallback=None)
|
||||
if redis_uri:
|
||||
os.environ["REDIS_URI"] = redis_uri
|
||||
rag_storage_config.KV_STORAGE = "RedisKVStorage"
|
||||
rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
|
||||
|
||||
# Neo4j config
|
||||
neo4j_uri = config.get("neo4j", "uri", fallback=None)
|
||||
neo4j_username = config.get("neo4j", "username", fallback=None)
|
||||
neo4j_password = config.get("neo4j", "password", fallback=None)
|
||||
if neo4j_uri:
|
||||
os.environ["NEO4J_URI"] = neo4j_uri
|
||||
os.environ["NEO4J_USERNAME"] = neo4j_username
|
||||
os.environ["NEO4J_PASSWORD"] = neo4j_password
|
||||
rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
|
||||
|
||||
# Milvus config
|
||||
milvus_uri = config.get("milvus", "uri", fallback=None)
|
||||
milvus_user = config.get("milvus", "user", fallback=None)
|
||||
milvus_password = config.get("milvus", "password", fallback=None)
|
||||
milvus_db_name = config.get("milvus", "db_name", fallback=None)
|
||||
if milvus_uri:
|
||||
os.environ["MILVUS_URI"] = milvus_uri
|
||||
os.environ["MILVUS_USER"] = milvus_user
|
||||
os.environ["MILVUS_PASSWORD"] = milvus_password
|
||||
os.environ["MILVUS_DB_NAME"] = milvus_db_name
|
||||
rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorage"
|
||||
|
||||
# Qdrant config
|
||||
qdrant_uri = config.get("qdrant", "uri", fallback=None)
|
||||
qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
|
||||
if qdrant_uri:
|
||||
os.environ["QDRANT_URL"] = qdrant_uri
|
||||
if qdrant_api_key:
|
||||
os.environ["QDRANT_API_KEY"] = qdrant_api_key
|
||||
rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
|
||||
|
||||
# MongoDB config
|
||||
mongo_uri = config.get("mongodb", "uri", fallback=None)
|
||||
mongo_database = config.get("mongodb", "database", fallback="LightRAG")
|
||||
mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
|
||||
if mongo_uri:
|
||||
os.environ["MONGO_URI"] = mongo_uri
|
||||
os.environ["MONGO_DATABASE"] = mongo_database
|
||||
rag_storage_config.KV_STORAGE = "MongoKVStorage"
|
||||
rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
|
||||
if mongo_graph:
|
||||
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
|
||||
|
||||
|
||||
def get_default_host(binding_type: str) -> str:
|
||||
default_hosts = {
|
||||
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
||||
@@ -247,6 +213,16 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.top_k}")
|
||||
|
||||
# System Configuration
|
||||
ASCIIColors.magenta("\n💾 Storage Configuration:")
|
||||
ASCIIColors.white(" ├─ KV Storage: ", end="")
|
||||
ASCIIColors.yellow(f"{args.kv_storage}")
|
||||
ASCIIColors.white(" ├─ Vector Storage: ", end="")
|
||||
ASCIIColors.yellow(f"{args.vector_storage}")
|
||||
ASCIIColors.white(" ├─ Graph Storage: ", end="")
|
||||
ASCIIColors.yellow(f"{args.graph_storage}")
|
||||
ASCIIColors.white(" └─ Document Status Storage: ", end="")
|
||||
ASCIIColors.yellow(f"{args.doc_status_storage}")
|
||||
|
||||
ASCIIColors.magenta("\n🛠️ System Configuration:")
|
||||
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
|
||||
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
|
||||
@@ -344,6 +320,35 @@ def parse_args() -> argparse.Namespace:
|
||||
description="LightRAG FastAPI Server with separate working and input directories"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--kv-storage",
|
||||
default=get_env_value(
|
||||
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
||||
),
|
||||
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--doc-status-storage",
|
||||
default=get_env_value(
|
||||
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
||||
),
|
||||
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--graph-storage",
|
||||
default=get_env_value(
|
||||
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
||||
),
|
||||
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vector-storage",
|
||||
default=get_env_value(
|
||||
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
||||
),
|
||||
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
||||
)
|
||||
|
||||
# Bindings configuration
|
||||
parser.add_argument(
|
||||
"--llm-binding",
|
||||
@@ -528,13 +533,13 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=get_env_value("TOP_K", 50, int),
|
||||
help="Number of most similar results to return (default: from env or 50)",
|
||||
default=get_env_value("TOP_K", 60, int),
|
||||
help="Number of most similar results to return (default: from env or 60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cosine-threshold",
|
||||
type=float,
|
||||
default=get_env_value("COSINE_THRESHOLD", 0.4, float),
|
||||
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
||||
help="Cosine similarity threshold (default: from env or 0.4)",
|
||||
)
|
||||
|
||||
@@ -667,7 +672,14 @@ def get_api_key_dependency(api_key: Optional[str]):
|
||||
return api_key_auth
|
||||
|
||||
|
||||
# Global configuration
|
||||
global_top_k = 60 # default value
|
||||
|
||||
|
||||
def create_app(args):
|
||||
global global_top_k
|
||||
global_top_k = args.top_k # save top_k from args
|
||||
|
||||
# Verify that bindings are correctly setup
|
||||
if args.llm_binding not in [
|
||||
"lollms",
|
||||
@@ -713,25 +725,104 @@ def create_app(args):
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events"""
|
||||
# Startup logic
|
||||
if args.auto_scan_at_startup:
|
||||
try:
|
||||
new_files = doc_manager.scan_directory_for_new_files()
|
||||
for file_path in new_files:
|
||||
try:
|
||||
await index_file(file_path)
|
||||
except Exception as e:
|
||||
trace_exception(e)
|
||||
logging.error(f"Error indexing file {file_path}: {str(e)}")
|
||||
# Initialize database connections
|
||||
postgres_db = None
|
||||
oracle_db = None
|
||||
tidb_db = None
|
||||
|
||||
ASCIIColors.info(
|
||||
f"Indexed {len(new_files)} documents from {args.input_dir}"
|
||||
try:
|
||||
# Check if PostgreSQL is needed
|
||||
if any(
|
||||
isinstance(
|
||||
storage_instance,
|
||||
(PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during startup indexing: {str(e)}")
|
||||
yield
|
||||
# Cleanup logic (if needed)
|
||||
pass
|
||||
for _, storage_instance in storage_instances
|
||||
):
|
||||
postgres_db = PostgreSQLDB(_get_postgres_config())
|
||||
await postgres_db.initdb()
|
||||
await postgres_db.check_tables()
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
if isinstance(
|
||||
storage_instance,
|
||||
(
|
||||
PGKVStorage,
|
||||
PGVectorStorage,
|
||||
PGGraphStorage,
|
||||
PGDocStatusStorage,
|
||||
),
|
||||
):
|
||||
storage_instance.db = postgres_db
|
||||
logger.info(f"Injected postgres_db to {storage_name}")
|
||||
|
||||
# Check if Oracle is needed
|
||||
if any(
|
||||
isinstance(
|
||||
storage_instance,
|
||||
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
||||
)
|
||||
for _, storage_instance in storage_instances
|
||||
):
|
||||
oracle_db = OracleDB(_get_oracle_config())
|
||||
await oracle_db.check_tables()
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
if isinstance(
|
||||
storage_instance,
|
||||
(OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
|
||||
):
|
||||
storage_instance.db = oracle_db
|
||||
logger.info(f"Injected oracle_db to {storage_name}")
|
||||
|
||||
# Check if TiDB is needed
|
||||
if any(
|
||||
isinstance(
|
||||
storage_instance,
|
||||
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
||||
)
|
||||
for _, storage_instance in storage_instances
|
||||
):
|
||||
tidb_db = TiDB(_get_tidb_config())
|
||||
await tidb_db.check_tables()
|
||||
for storage_name, storage_instance in storage_instances:
|
||||
if isinstance(
|
||||
storage_instance,
|
||||
(TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
|
||||
):
|
||||
storage_instance.db = tidb_db
|
||||
logger.info(f"Injected tidb_db to {storage_name}")
|
||||
|
||||
# Auto scan documents if enabled
|
||||
if args.auto_scan_at_startup:
|
||||
try:
|
||||
new_files = doc_manager.scan_directory_for_new_files()
|
||||
for file_path in new_files:
|
||||
try:
|
||||
await index_file(file_path)
|
||||
except Exception as e:
|
||||
trace_exception(e)
|
||||
logging.error(f"Error indexing file {file_path}: {str(e)}")
|
||||
|
||||
ASCIIColors.info(
|
||||
f"Indexed {len(new_files)} documents from {args.input_dir}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during startup indexing: {str(e)}")
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Cleanup database connections
|
||||
if postgres_db and hasattr(postgres_db, "pool"):
|
||||
await postgres_db.pool.close()
|
||||
logger.info("Closed PostgreSQL connection pool")
|
||||
|
||||
if oracle_db and hasattr(oracle_db, "pool"):
|
||||
await oracle_db.pool.close()
|
||||
logger.info("Closed Oracle connection pool")
|
||||
|
||||
if tidb_db and hasattr(tidb_db, "pool"):
|
||||
await tidb_db.pool.close()
|
||||
logger.info("Closed TiDB connection pool")
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
@@ -754,6 +845,92 @@ def create_app(args):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Database configuration functions
|
||||
def _get_postgres_config():
|
||||
return {
|
||||
"host": os.environ.get(
|
||||
"POSTGRES_HOST",
|
||||
config.get("postgres", "host", fallback="localhost"),
|
||||
),
|
||||
"port": os.environ.get(
|
||||
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
|
||||
),
|
||||
"user": os.environ.get(
|
||||
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
|
||||
),
|
||||
"password": os.environ.get(
|
||||
"POSTGRES_PASSWORD",
|
||||
config.get("postgres", "password", fallback=None),
|
||||
),
|
||||
"database": os.environ.get(
|
||||
"POSTGRES_DATABASE",
|
||||
config.get("postgres", "database", fallback=None),
|
||||
),
|
||||
"workspace": os.environ.get(
|
||||
"POSTGRES_WORKSPACE",
|
||||
config.get("postgres", "workspace", fallback="default"),
|
||||
),
|
||||
}
|
||||
|
||||
def _get_oracle_config():
|
||||
return {
|
||||
"user": os.environ.get(
|
||||
"ORACLE_USER",
|
||||
config.get("oracle", "user", fallback=None),
|
||||
),
|
||||
"password": os.environ.get(
|
||||
"ORACLE_PASSWORD",
|
||||
config.get("oracle", "password", fallback=None),
|
||||
),
|
||||
"dsn": os.environ.get(
|
||||
"ORACLE_DSN",
|
||||
config.get("oracle", "dsn", fallback=None),
|
||||
),
|
||||
"config_dir": os.environ.get(
|
||||
"ORACLE_CONFIG_DIR",
|
||||
config.get("oracle", "config_dir", fallback=None),
|
||||
),
|
||||
"wallet_location": os.environ.get(
|
||||
"ORACLE_WALLET_LOCATION",
|
||||
config.get("oracle", "wallet_location", fallback=None),
|
||||
),
|
||||
"wallet_password": os.environ.get(
|
||||
"ORACLE_WALLET_PASSWORD",
|
||||
config.get("oracle", "wallet_password", fallback=None),
|
||||
),
|
||||
"workspace": os.environ.get(
|
||||
"ORACLE_WORKSPACE",
|
||||
config.get("oracle", "workspace", fallback="default"),
|
||||
),
|
||||
}
|
||||
|
||||
def _get_tidb_config():
|
||||
return {
|
||||
"host": os.environ.get(
|
||||
"TIDB_HOST",
|
||||
config.get("tidb", "host", fallback="localhost"),
|
||||
),
|
||||
"port": os.environ.get(
|
||||
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
|
||||
),
|
||||
"user": os.environ.get(
|
||||
"TIDB_USER",
|
||||
config.get("tidb", "user", fallback=None),
|
||||
),
|
||||
"password": os.environ.get(
|
||||
"TIDB_PASSWORD",
|
||||
config.get("tidb", "password", fallback=None),
|
||||
),
|
||||
"database": os.environ.get(
|
||||
"TIDB_DATABASE",
|
||||
config.get("tidb", "database", fallback=None),
|
||||
),
|
||||
"workspace": os.environ.get(
|
||||
"TIDB_WORKSPACE",
|
||||
config.get("tidb", "workspace", fallback="default"),
|
||||
),
|
||||
}
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
@@ -872,10 +1049,10 @@ def create_app(args):
|
||||
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
|
||||
else {},
|
||||
embedding_func=embedding_func,
|
||||
kv_storage=rag_storage_config.KV_STORAGE,
|
||||
graph_storage=rag_storage_config.GRAPH_STORAGE,
|
||||
vector_storage=rag_storage_config.VECTOR_STORAGE,
|
||||
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
|
||||
kv_storage=args.kv_storage,
|
||||
graph_storage=args.graph_storage,
|
||||
vector_storage=args.vector_storage,
|
||||
doc_status_storage=args.doc_status_storage,
|
||||
vector_db_storage_cls_kwargs={
|
||||
"cosine_better_than_threshold": args.cosine_threshold
|
||||
},
|
||||
@@ -903,10 +1080,10 @@ def create_app(args):
|
||||
llm_model_max_async=args.max_async,
|
||||
llm_model_max_token_size=args.max_tokens,
|
||||
embedding_func=embedding_func,
|
||||
kv_storage=rag_storage_config.KV_STORAGE,
|
||||
graph_storage=rag_storage_config.GRAPH_STORAGE,
|
||||
vector_storage=rag_storage_config.VECTOR_STORAGE,
|
||||
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
|
||||
kv_storage=args.kv_storage,
|
||||
graph_storage=args.graph_storage,
|
||||
vector_storage=args.vector_storage,
|
||||
doc_status_storage=args.doc_status_storage,
|
||||
vector_db_storage_cls_kwargs={
|
||||
"cosine_better_than_threshold": args.cosine_threshold
|
||||
},
|
||||
@@ -920,6 +1097,18 @@ def create_app(args):
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
)
|
||||
|
||||
# Collect all storage instances
|
||||
storage_instances = [
|
||||
("full_docs", rag.full_docs),
|
||||
("text_chunks", rag.text_chunks),
|
||||
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
|
||||
("entities_vdb", rag.entities_vdb),
|
||||
("relationships_vdb", rag.relationships_vdb),
|
||||
("chunks_vdb", rag.chunks_vdb),
|
||||
("doc_status", rag.doc_status),
|
||||
("llm_response_cache", rag.llm_response_cache),
|
||||
]
|
||||
|
||||
async def index_file(file_path: Union[str, Path]) -> None:
|
||||
"""Index all files inside the folder with support for multiple file formats
|
||||
|
||||
@@ -1100,7 +1289,7 @@ def create_app(args):
|
||||
mode=request.mode,
|
||||
stream=request.stream,
|
||||
only_need_context=request.only_need_context,
|
||||
top_k=args.top_k,
|
||||
top_k=global_top_k,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1142,7 +1331,7 @@ def create_app(args):
|
||||
mode=request.mode,
|
||||
stream=True,
|
||||
only_need_context=request.only_need_context,
|
||||
top_k=args.top_k,
|
||||
top_k=global_top_k,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1432,7 +1621,7 @@ def create_app(args):
|
||||
return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
|
||||
|
||||
# Add Ollama API routes
|
||||
ollama_api = OllamaAPI(rag)
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||
app.include_router(ollama_api.router, prefix="/api")
|
||||
|
||||
@app.get("/documents", dependencies=[Depends(optional_api_key)])
|
||||
@@ -1460,10 +1649,10 @@ def create_app(args):
|
||||
"embedding_binding_host": args.embedding_binding_host,
|
||||
"embedding_model": args.embedding_model,
|
||||
"max_tokens": args.max_tokens,
|
||||
"kv_storage": rag_storage_config.KV_STORAGE,
|
||||
"doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
|
||||
"graph_storage": rag_storage_config.GRAPH_STORAGE,
|
||||
"vector_storage": rag_storage_config.VECTOR_STORAGE,
|
||||
"kv_storage": args.kv_storage,
|
||||
"doc_status_storage": args.doc_status_storage,
|
||||
"graph_storage": args.graph_storage,
|
||||
"vector_storage": args.vector_storage,
|
||||
},
|
||||
}
|
||||
|
||||
|
@@ -148,9 +148,10 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
|
||||
|
||||
|
||||
class OllamaAPI:
|
||||
def __init__(self, rag: LightRAG):
|
||||
def __init__(self, rag: LightRAG, top_k: int = 60):
|
||||
self.rag = rag
|
||||
self.ollama_server_infos = ollama_server_infos
|
||||
self.top_k = top_k
|
||||
self.router = APIRouter()
|
||||
self.setup_routes()
|
||||
|
||||
@@ -381,7 +382,7 @@ class OllamaAPI:
|
||||
"stream": request.stream,
|
||||
"only_need_context": False,
|
||||
"conversation_history": conversation_history,
|
||||
"top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
|
||||
if (
|
||||
|
@@ -75,8 +75,8 @@ class AGEStorage(BaseGraphStorage):
|
||||
.replace("'", "\\'")
|
||||
)
|
||||
HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'")
|
||||
PORT = int(os.environ["AGE_POSTGRES_PORT"])
|
||||
self.graph_name = os.environ["AGE_GRAPH_NAME"]
|
||||
PORT = os.environ.get("AGE_POSTGRES_PORT", "8529")
|
||||
self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
||||
|
||||
connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
|
||||
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
@@ -13,15 +12,17 @@ from lightrag.utils import logger
|
||||
class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
"""ChromaDB vector storage implementation."""
|
||||
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
try:
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
user_collection_settings = config.get("collection_settings", {})
|
||||
# Default HNSW index settings for ChromaDB
|
||||
|
@@ -23,14 +23,17 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
||||
"""
|
||||
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Grab config values if available
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
# Where to save index file if you want persistent storage
|
||||
self._faiss_index_file = os.path.join(
|
||||
|
@@ -47,7 +47,9 @@ class GremlinStorage(BaseGraphStorage):
|
||||
|
||||
# All vertices will have graph={GRAPH} property, so that we can
|
||||
# have several logical graphs for one source
|
||||
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
|
||||
GRAPH = GremlinStorage._to_value_map(
|
||||
os.environ.get("GREMLIN_GRAPH", "LightRAG")
|
||||
)
|
||||
|
||||
self.graph_name = GRAPH
|
||||
|
||||
|
@@ -5,16 +5,22 @@ from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
|
||||
if not pm.is_installed("pymilvus"):
|
||||
pm.install("pymilvus")
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
@staticmethod
|
||||
def create_collection_if_not_exist(
|
||||
client: MilvusClient, collection_name: str, **kwargs
|
||||
@@ -26,15 +32,37 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
self._client = MilvusClient(
|
||||
uri=os.environ.get(
|
||||
"MILVUS_URI",
|
||||
os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
|
||||
config.get(
|
||||
"milvus",
|
||||
"uri",
|
||||
fallback=os.path.join(
|
||||
self.global_config["working_dir"], "milvus_lite.db"
|
||||
),
|
||||
),
|
||||
),
|
||||
user=os.environ.get(
|
||||
"MILVUS_USER", config.get("milvus", "user", fallback=None)
|
||||
),
|
||||
password=os.environ.get(
|
||||
"MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)
|
||||
),
|
||||
token=os.environ.get(
|
||||
"MILVUS_TOKEN", config.get("milvus", "token", fallback=None)
|
||||
),
|
||||
db_name=os.environ.get(
|
||||
"MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)
|
||||
),
|
||||
user=os.environ.get("MILVUS_USER", ""),
|
||||
password=os.environ.get("MILVUS_PASSWORD", ""),
|
||||
token=os.environ.get("MILVUS_TOKEN", ""),
|
||||
db_name=os.environ.get("MILVUS_DB_NAME", ""),
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
||||
@@ -85,7 +113,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
data=embedding,
|
||||
limit=top_k,
|
||||
output_fields=list(self.meta_fields),
|
||||
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
|
||||
search_params={
|
||||
"metric_type": "COSINE",
|
||||
"params": {"radius": self.cosine_better_than_threshold},
|
||||
},
|
||||
)
|
||||
print(results)
|
||||
return [
|
||||
|
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
|
||||
if not pm.is_installed("pymongo"):
|
||||
@@ -12,7 +12,6 @@ if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
|
||||
@@ -27,13 +26,27 @@ from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MongoKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
client = MongoClient(
|
||||
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
|
||||
os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
|
||||
self._data = database.get_collection(self.namespace)
|
||||
logger.info(f"Use MongoDB as KV {self.namespace}")
|
||||
|
||||
@@ -173,10 +186,25 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self.client = AsyncIOMotorClient(
|
||||
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
|
||||
os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
)
|
||||
self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
|
||||
self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]
|
||||
self.db = self.client[
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
]
|
||||
self.collection = self.db[
|
||||
os.environ.get(
|
||||
"MONGO_KG_COLLECTION",
|
||||
config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
|
||||
)
|
||||
]
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
|
@@ -73,16 +73,19 @@ from lightrag.base import (
|
||||
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Initialize lock only for file operations
|
||||
self._save_lock = asyncio.Lock()
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
@@ -139,9 +142,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
async def query(self, query: str, top_k=5):
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
logger.info(
|
||||
f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
|
||||
)
|
||||
results = self._client.query(
|
||||
query=embedding,
|
||||
top_k=top_k,
|
||||
|
@@ -5,6 +5,7 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, Tuple, List, Dict
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
|
||||
if not pm.is_installed("neo4j"):
|
||||
pm.install("neo4j")
|
||||
@@ -28,6 +29,10 @@ from ..base import BaseGraphStorage
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Neo4JStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
@@ -42,13 +47,22 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
self._driver = None
|
||||
self._driver_lock = asyncio.Lock()
|
||||
URI = os.environ["NEO4J_URI"]
|
||||
USERNAME = os.environ["NEO4J_USERNAME"]
|
||||
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
||||
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
|
||||
|
||||
URI = os.environ["NEO4J_URI", config.get("neo4j", "uri", fallback=None)]
|
||||
USERNAME = os.environ[
|
||||
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
||||
]
|
||||
PASSWORD = os.environ[
|
||||
"NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
|
||||
]
|
||||
MAX_CONNECTION_POOL_SIZE = os.environ.get(
|
||||
"NEO4J_MAX_CONNECTION_POOL_SIZE",
|
||||
config.get("neo4j", "connection_pool_size", fallback=800),
|
||||
)
|
||||
DATABASE = os.environ.get(
|
||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
||||
)
|
||||
|
||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||
URI, auth=(USERNAME, PASSWORD)
|
||||
)
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import array
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
# import html
|
||||
# import os
|
||||
@@ -172,8 +171,8 @@ class OracleDB:
|
||||
|
||||
@dataclass
|
||||
class OracleKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
db: OracleDB = None
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
meta_fields = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -318,16 +317,18 @@ class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class OracleVectorDBStorage(BaseVectorStorage):
|
||||
# should pass db object to self.db
|
||||
db: OracleDB = None
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""向向量数据库中插入数据"""
|
||||
@@ -361,7 +362,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
@dataclass
|
||||
class OracleGraphStorage(BaseGraphStorage):
|
||||
"""基于Oracle的图存储模块"""
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
|
||||
def __post_init__(self):
|
||||
"""从graphml文件加载图"""
|
||||
|
@@ -177,7 +177,8 @@ class PostgreSQLDB:
|
||||
|
||||
@dataclass
|
||||
class PGKVStorage(BaseKVStorage):
|
||||
db: PostgreSQLDB = None
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
@@ -296,16 +297,19 @@ class PGKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class PGVectorStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
db: PostgreSQLDB = None
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
def _upsert_chunks(self, item: dict):
|
||||
try:
|
||||
@@ -416,20 +420,14 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
|
||||
@dataclass
|
||||
class PGDocStatusStorage(DocStatusStorage):
|
||||
"""PostgreSQL implementation of document status storage"""
|
||||
|
||||
db: PostgreSQLDB = None
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
keys = ",".join([f"'{_id}'" for _id in data])
|
||||
sql = (
|
||||
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})"
|
||||
)
|
||||
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
|
||||
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
||||
result = await self.db.query(sql, multirows=True)
|
||||
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
||||
if result is None:
|
||||
return set(data)
|
||||
@@ -585,19 +583,15 @@ class PGGraphQueryException(Exception):
|
||||
|
||||
@dataclass
|
||||
class PGGraphStorage(BaseGraphStorage):
|
||||
db: PostgreSQLDB = None
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with AGE in production")
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self.graph_name = os.environ["AGE_GRAPH_NAME"]
|
||||
def __post_init__(self):
|
||||
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
@@ -1137,7 +1131,7 @@ TABLES = {
|
||||
"ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
|
||||
workspace varchar(255) NOT NULL,
|
||||
id varchar(255) NOT NULL,
|
||||
content TEXT,
|
||||
content TEXT NULL,
|
||||
content_summary varchar(255) NULL,
|
||||
content_length int4 NULL,
|
||||
chunks_count int4 NULL,
|
||||
|
@@ -1,138 +0,0 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("psycopg-pool"):
|
||||
pm.install("psycopg-pool")
|
||||
pm.install("psycopg[binary,pool]")
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
import asyncpg
|
||||
import psycopg
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
|
||||
from ..namespace import NameSpace
|
||||
|
||||
DB = "rag"
|
||||
USER = "rag"
|
||||
PASSWORD = "rag"
|
||||
HOST = "localhost"
|
||||
PORT = "15432"
|
||||
os.environ["AGE_GRAPH_NAME"] = "dickens"
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
import asyncio.windows_events
|
||||
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
async def get_pool():
|
||||
return await asyncpg.create_pool(
|
||||
f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
|
||||
min_size=10,
|
||||
max_size=10,
|
||||
max_queries=5000,
|
||||
max_inactive_connection_lifetime=300.0,
|
||||
)
|
||||
|
||||
|
||||
async def main1():
|
||||
connection_string = (
|
||||
f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
|
||||
)
|
||||
pool = AsyncConnectionPool(connection_string, open=False)
|
||||
await pool.open()
|
||||
|
||||
try:
|
||||
conn = await pool.getconn(timeout=10)
|
||||
async with conn.cursor() as curs:
|
||||
try:
|
||||
await curs.execute('SET search_path = ag_catalog, "$user", public')
|
||||
await curs.execute("SELECT create_graph('dickens-2')")
|
||||
await conn.commit()
|
||||
print("create_graph success")
|
||||
except (
|
||||
psycopg.errors.InvalidSchemaName,
|
||||
psycopg.errors.UniqueViolation,
|
||||
):
|
||||
print("create_graph already exists")
|
||||
await conn.rollback()
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
db = PostgreSQLDB(
|
||||
config={
|
||||
"host": "localhost",
|
||||
"port": 15432,
|
||||
"user": "rag",
|
||||
"password": "rag",
|
||||
"database": "r1",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def query_with_age():
|
||||
await db.initdb()
|
||||
graph = PGGraphStorage(
|
||||
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
|
||||
global_config={},
|
||||
embedding_func=None,
|
||||
)
|
||||
graph.db = db
|
||||
res = await graph.get_node('"A CHRISTMAS CAROL"')
|
||||
print("Node is: ", res)
|
||||
res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
|
||||
print("Edge is: ", res)
|
||||
res = await graph.get_node_edges('"SCROOGE"')
|
||||
print("Node Edges are: ", res)
|
||||
|
||||
|
||||
async def create_edge_with_age():
|
||||
await db.initdb()
|
||||
graph = PGGraphStorage(
|
||||
namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
|
||||
global_config={},
|
||||
embedding_func=None,
|
||||
)
|
||||
graph.db = db
|
||||
await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
|
||||
await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
|
||||
await graph.upsert_edge(
|
||||
'"THE CRATCHITS"',
|
||||
'"THE GIRLS"',
|
||||
edge_data={
|
||||
"weight": 7.0,
|
||||
"description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
|
||||
"keywords": '"family, collective effort"',
|
||||
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
|
||||
},
|
||||
)
|
||||
res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
|
||||
print("Edge is: ", res)
|
||||
|
||||
|
||||
async def main():
|
||||
pool = await get_pool()
|
||||
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
|
||||
# cypher = "MATCH (n:how_are_you_doing) RETURN n"
|
||||
async with pool.acquire() as conn:
|
||||
try:
|
||||
await conn.execute(
|
||||
"""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
|
||||
)
|
||||
except asyncpg.exceptions.InvalidSchemaNameError:
|
||||
print("create_graph already exists")
|
||||
# stmt = await conn.prepare(sql)
|
||||
row = await conn.fetch(sql)
|
||||
print("row is: ", row)
|
||||
|
||||
row = await conn.fetchrow("select '100'::int + 200 as result")
|
||||
print(row) # <Record result=300>
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(query_with_age())
|
@@ -5,11 +5,10 @@ from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
|
||||
if not pm.is_installed("qdrant_client"):
|
||||
pm.install("qdrant_client")
|
||||
@@ -17,6 +16,10 @@ if not pm.is_installed("qdrant_client"):
|
||||
from qdrant_client import QdrantClient, models
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
def compute_mdhash_id_for_qdrant(
|
||||
content: str, prefix: str = "", style: str = "simple"
|
||||
) -> str:
|
||||
@@ -47,6 +50,8 @@ def compute_mdhash_id_for_qdrant(
|
||||
|
||||
@dataclass
|
||||
class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
@staticmethod
|
||||
def create_collection_if_not_exist(
|
||||
client: QdrantClient, collection_name: str, **kwargs
|
||||
@@ -56,9 +61,21 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
client.create_collection(collection_name, **kwargs)
|
||||
|
||||
def __post_init__(self):
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
self._client = QdrantClient(
|
||||
url=os.environ.get("QDRANT_URL"),
|
||||
api_key=os.environ.get("QDRANT_API_KEY", None),
|
||||
url=os.environ.get(
|
||||
"QDRANT_URL", config.get("qdrant", "uri", fallback=None)
|
||||
),
|
||||
api_key=os.environ.get(
|
||||
"QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)
|
||||
),
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
QdrantVectorDBStorage.create_collection_if_not_exist(
|
||||
@@ -122,4 +139,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
limit=top_k,
|
||||
with_payload=True,
|
||||
)
|
||||
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
|
||||
logger.debug(f"query result: {results}")
|
||||
# 添加余弦相似度过滤
|
||||
filtered_results = [
|
||||
dp for dp in results if dp.score >= self.cosine_better_than_threshold
|
||||
]
|
||||
return [
|
||||
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
|
||||
]
|
||||
|
@@ -3,6 +3,7 @@ from typing import Any, Union
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import pipmaster as pm
|
||||
import configparser
|
||||
|
||||
if not pm.is_installed("redis"):
|
||||
pm.install("redis")
|
||||
@@ -14,10 +15,16 @@ from lightrag.base import BaseKVStorage
|
||||
import json
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedisKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
|
||||
redis_url = os.environ.get(
|
||||
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
||||
)
|
||||
self._redis = Redis.from_url(redis_url, decode_responses=True)
|
||||
logger.info(f"Use Redis as KV {self.namespace}")
|
||||
|
||||
|
@@ -101,7 +101,9 @@ class TiDB:
|
||||
|
||||
@dataclass
|
||||
class TiDBKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
|
||||
def __post_init__(self):
|
||||
self._data = {}
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
@@ -208,18 +210,22 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||
if cosine_threshold is None:
|
||||
raise ValueError(
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||
"""Search from tidb vector"""
|
||||
@@ -329,6 +335,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import configparser
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
@@ -36,6 +37,111 @@ from .utils import (
|
||||
)
|
||||
from .types import KnowledgeGraph
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
# Storage type and implementation compatibility validation table
|
||||
STORAGE_IMPLEMENTATIONS = {
|
||||
"KV_STORAGE": {
|
||||
"implementations": [
|
||||
"JsonKVStorage",
|
||||
"MongoKVStorage",
|
||||
"RedisKVStorage",
|
||||
"TiDBKVStorage",
|
||||
"PGKVStorage",
|
||||
"OracleKVStorage",
|
||||
],
|
||||
"required_methods": ["get_by_id", "upsert"],
|
||||
},
|
||||
"GRAPH_STORAGE": {
|
||||
"implementations": [
|
||||
"NetworkXStorage",
|
||||
"Neo4JStorage",
|
||||
"MongoGraphStorage",
|
||||
"TiDBGraphStorage",
|
||||
"AGEStorage",
|
||||
"GremlinStorage",
|
||||
"PGGraphStorage",
|
||||
"OracleGraphStorage",
|
||||
],
|
||||
"required_methods": ["upsert_node", "upsert_edge"],
|
||||
},
|
||||
"VECTOR_STORAGE": {
|
||||
"implementations": [
|
||||
"NanoVectorDBStorage",
|
||||
"MilvusVectorDBStorge",
|
||||
"ChromaVectorDBStorage",
|
||||
"TiDBVectorDBStorage",
|
||||
"PGVectorStorage",
|
||||
"FaissVectorDBStorage",
|
||||
"QdrantVectorDBStorage",
|
||||
"OracleVectorDBStorage",
|
||||
],
|
||||
"required_methods": ["query", "upsert"],
|
||||
},
|
||||
"DOC_STATUS_STORAGE": {
|
||||
"implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"],
|
||||
"required_methods": ["get_pending_docs"],
|
||||
},
|
||||
}
|
||||
|
||||
# Storage implementation environment variable without default value
|
||||
STORAGE_ENV_REQUIREMENTS = {
|
||||
# KV Storage Implementations
|
||||
"JsonKVStorage": [],
|
||||
"MongoKVStorage": [],
|
||||
"RedisKVStorage": ["REDIS_URI"],
|
||||
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"OracleKVStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
# Graph Storage Implementations
|
||||
"NetworkXStorage": [],
|
||||
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
||||
"MongoGraphStorage": [],
|
||||
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"AGEStorage": [
|
||||
"AGE_POSTGRES_DB",
|
||||
"AGE_POSTGRES_USER",
|
||||
"AGE_POSTGRES_PASSWORD",
|
||||
],
|
||||
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
|
||||
"PGGraphStorage": [
|
||||
"POSTGRES_USER",
|
||||
"POSTGRES_PASSWORD",
|
||||
"POSTGRES_DATABASE",
|
||||
],
|
||||
"OracleGraphStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
# Vector Storage Implementations
|
||||
"NanoVectorDBStorage": [],
|
||||
"MilvusVectorDBStorge": [],
|
||||
"ChromaVectorDBStorage": [],
|
||||
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"FaissVectorDBStorage": [],
|
||||
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
|
||||
"OracleVectorDBStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
# Document Status Storage Implementations
|
||||
"JsonDocStatusStorage": [],
|
||||
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"MongoDocStatusStorage": [],
|
||||
}
|
||||
|
||||
# Storage implementation module mapping
|
||||
STORAGES = {
|
||||
"NetworkXStorage": ".kg.networkx_impl",
|
||||
"JsonKVStorage": ".kg.json_kv_impl",
|
||||
@@ -140,6 +246,9 @@ class LightRAG:
|
||||
graph_storage: str = field(default="NetworkXStorage")
|
||||
"""Storage backend for knowledge graphs."""
|
||||
|
||||
doc_status_storage: str = field(default="JsonDocStatusStorage")
|
||||
"""Storage type for tracking document processing statuses."""
|
||||
|
||||
# Logging
|
||||
current_log_level = logger.level
|
||||
log_level: int = field(default=current_log_level)
|
||||
@@ -236,9 +345,6 @@ class LightRAG:
|
||||
convert_response_to_json
|
||||
)
|
||||
|
||||
doc_status_storage: str = field(default="JsonDocStatusStorage")
|
||||
"""Storage type for tracking document processing statuses."""
|
||||
|
||||
# Custom Chunking Function
|
||||
chunking_func: Callable[
|
||||
[
|
||||
@@ -252,6 +358,46 @@ class LightRAG:
|
||||
list[dict[str, Any]],
|
||||
] = chunking_by_token_size
|
||||
|
||||
def verify_storage_implementation(
|
||||
self, storage_type: str, storage_name: str
|
||||
) -> None:
|
||||
"""Verify if storage implementation is compatible with specified storage type
|
||||
|
||||
Args:
|
||||
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If storage implementation is incompatible or missing required methods
|
||||
"""
|
||||
if storage_type not in STORAGE_IMPLEMENTATIONS:
|
||||
raise ValueError(f"Unknown storage type: {storage_type}")
|
||||
|
||||
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
|
||||
if storage_name not in storage_info["implementations"]:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
|
||||
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
|
||||
)
|
||||
|
||||
def check_storage_env_vars(self, storage_name: str) -> None:
|
||||
"""Check if all required environment variables for storage implementation exist
|
||||
|
||||
Args:
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are missing
|
||||
"""
|
||||
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' requires the following "
|
||||
f"environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
log_file = os.path.join(self.log_dir, "lightrag.log")
|
||||
@@ -263,6 +409,29 @@ class LightRAG:
|
||||
logger.info(f"Creating working directory {self.working_dir}")
|
||||
os.makedirs(self.working_dir)
|
||||
|
||||
# Verify storage implementation compatibility and environment variables
|
||||
storage_configs = [
|
||||
("KV_STORAGE", self.kv_storage),
|
||||
("VECTOR_STORAGE", self.vector_storage),
|
||||
("GRAPH_STORAGE", self.graph_storage),
|
||||
("DOC_STATUS_STORAGE", self.doc_status_storage),
|
||||
]
|
||||
|
||||
for storage_type, storage_name in storage_configs:
|
||||
# Verify storage implementation compatibility
|
||||
self.verify_storage_implementation(storage_type, storage_name)
|
||||
# Check environment variables
|
||||
self.check_storage_env_vars(storage_name)
|
||||
|
||||
# Ensure vector_db_storage_cls_kwargs has required fields
|
||||
default_vector_db_kwargs = {
|
||||
"cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
}
|
||||
self.vector_db_storage_cls_kwargs = {
|
||||
**default_vector_db_kwargs,
|
||||
**self.vector_db_storage_cls_kwargs,
|
||||
}
|
||||
|
||||
# show config
|
||||
global_config = asdict(self)
|
||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
||||
@@ -296,10 +465,8 @@ class LightRAG:
|
||||
self.graph_storage_cls, global_config=global_config
|
||||
)
|
||||
|
||||
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
|
||||
namespace=self.namespace_prefix + "json_doc_status_storage",
|
||||
embedding_func=None,
|
||||
)
|
||||
# Initialize document status storage
|
||||
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
||||
|
||||
self.llm_response_cache = self.key_string_value_json_storage_cls(
|
||||
namespace=make_namespace(
|
||||
@@ -308,9 +475,6 @@ class LightRAG:
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
####
|
||||
# add embedding func by walter
|
||||
####
|
||||
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
|
||||
@@ -329,9 +493,6 @@ class LightRAG:
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
####
|
||||
# add embedding func by walter over
|
||||
####
|
||||
|
||||
self.entities_vdb = self.vector_db_storage_cls(
|
||||
namespace=make_namespace(
|
||||
@@ -354,6 +515,14 @@ class LightRAG:
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
# Initialize document status storage
|
||||
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
|
||||
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
||||
global_config=global_config,
|
||||
embedding_func=None,
|
||||
)
|
||||
|
||||
# What's for, Is this nessisary ?
|
||||
if self.llm_response_cache and hasattr(
|
||||
self.llm_response_cache, "global_config"
|
||||
):
|
||||
@@ -374,14 +543,6 @@ class LightRAG:
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize document status storage
|
||||
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
||||
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
|
||||
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
||||
global_config=global_config,
|
||||
embedding_func=None,
|
||||
)
|
||||
|
||||
async def get_graph_labels(self):
|
||||
text = await self.chunk_entity_relation_graph.get_all_labels()
|
||||
return text
|
||||
@@ -399,7 +560,8 @@ class LightRAG:
|
||||
return storage_class
|
||||
|
||||
def set_storage_client(self, db_client):
|
||||
# Now only tested on Oracle Database
|
||||
# Deprecated, seting correct value to *_storage of LightRAG insteaded
|
||||
# Inject db to storage implementation (only tested on Oracle Database)
|
||||
for storage in [
|
||||
self.vector_db_storage_cls,
|
||||
self.graph_storage_cls,
|
||||
|
@@ -1055,6 +1055,9 @@ async def _get_node_data(
|
||||
query_param: QueryParam,
|
||||
):
|
||||
# get similar entities
|
||||
logger.info(
|
||||
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
||||
)
|
||||
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
@@ -1270,6 +1273,9 @@ async def _get_edge_data(
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
):
|
||||
logger.info(
|
||||
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
||||
)
|
||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||
|
||||
if not len(results):
|
||||
|
@@ -416,7 +416,13 @@ async def get_best_cached_response(
|
||||
|
||||
if best_similarity > similarity_threshold:
|
||||
# If LLM check is enabled and all required parameters are provided
|
||||
if use_llm_check and llm_func and original_prompt and best_prompt:
|
||||
if (
|
||||
use_llm_check
|
||||
and llm_func
|
||||
and original_prompt
|
||||
and best_prompt
|
||||
and best_response is not None
|
||||
):
|
||||
compare_prompt = PROMPTS["similarity_check"].format(
|
||||
original_prompt=original_prompt, cached_prompt=best_prompt
|
||||
)
|
||||
@@ -430,7 +436,9 @@ async def get_best_cached_response(
|
||||
best_similarity = llm_similarity
|
||||
if best_similarity < similarity_threshold:
|
||||
log_data = {
|
||||
"event": "llm_check_cache_rejected",
|
||||
"event": "cache_rejected_by_llm",
|
||||
"type": cache_type,
|
||||
"mode": mode,
|
||||
"original_question": original_prompt[:100] + "..."
|
||||
if len(original_prompt) > 100
|
||||
else original_prompt,
|
||||
@@ -440,7 +448,8 @@ async def get_best_cached_response(
|
||||
"similarity_score": round(best_similarity, 4),
|
||||
"threshold": similarity_threshold,
|
||||
}
|
||||
logger.info(json.dumps(log_data, ensure_ascii=False))
|
||||
logger.debug(json.dumps(log_data, ensure_ascii=False))
|
||||
logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
|
||||
return None
|
||||
except Exception as e: # Catch all possible exceptions
|
||||
logger.warning(f"LLM similarity check failed: {e}")
|
||||
@@ -451,12 +460,13 @@ async def get_best_cached_response(
|
||||
)
|
||||
log_data = {
|
||||
"event": "cache_hit",
|
||||
"type": cache_type,
|
||||
"mode": mode,
|
||||
"similarity": round(best_similarity, 4),
|
||||
"cache_id": best_cache_id,
|
||||
"original_prompt": prompt_display,
|
||||
}
|
||||
logger.info(json.dumps(log_data, ensure_ascii=False))
|
||||
logger.debug(json.dumps(log_data, ensure_ascii=False))
|
||||
return best_response
|
||||
return None
|
||||
|
||||
@@ -534,19 +544,24 @@ async def handle_cache(
|
||||
cache_type=cache_type,
|
||||
)
|
||||
if best_cached_response is not None:
|
||||
logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
|
||||
return best_cached_response, None, None, None
|
||||
else:
|
||||
# if caching keyword embedding is enabled, return the quantized embedding for saving it latter
|
||||
logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
|
||||
return None, quantized, min_val, max_val
|
||||
|
||||
# For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
|
||||
# Use regular cache
|
||||
# For default mode or is_embedding_cache_enabled is False, use regular cache
|
||||
# default mode is for extract_entities or naive query
|
||||
if exists_func(hashing_kv, "get_by_mode_and_id"):
|
||||
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
|
||||
else:
|
||||
mode_cache = await hashing_kv.get_by_id(mode) or {}
|
||||
if args_hash in mode_cache:
|
||||
logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
|
||||
return mode_cache[args_hash]["return"], None, None, None
|
||||
|
||||
logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
|
||||
return None, None, None, None
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user