Merge pull request #744 from danielaskdd/select-datastore-in-api-server
Add datastore selection feature for API Server

.env.example
@@ -1,12 +1,30 @@
-# Server Configuration
-HOST=0.0.0.0
-PORT=9621
+### Server Configuration
+#HOST=0.0.0.0
+#PORT=9621
+#NAMESPACE_PREFIX=lightrag # separates data from different LightRAG instances

-# Directory Configuration
-WORKING_DIR=/app/data/rag_storage
-INPUT_DIR=/app/data/inputs
+### Optional SSL Configuration
+#SSL=true
+#SSL_CERTFILE=/path/to/cert.pem
+#SSL_KEYFILE=/path/to/key.pem

-# RAG Configuration
+### Security (leave empty if no api-key is needed)
+# LIGHTRAG_API_KEY=your-secure-api-key-here
+
+### Directory Configuration
+# WORKING_DIR=./rag_storage
+# INPUT_DIR=./inputs
+
+### Logging level
+LOG_LEVEL=INFO
+
+### Optional Timeout
+TIMEOUT=300
+
+# Ollama Emulating Model Tag
+# OLLAMA_EMULATING_MODEL_TAG=latest
+
+### RAG Configuration
 MAX_ASYNC=4
 MAX_TOKENS=32768
 EMBEDDING_DIM=1024
@@ -14,56 +32,42 @@ MAX_EMBED_TOKENS=8192
 #HISTORY_TURNS=3
 #CHUNK_SIZE=1200
 #CHUNK_OVERLAP_SIZE=100
-#COSINE_THRESHOLD=0.4 # 0.2 while not running API server
-#TOP_K=50 # 60 while not running API server
+#COSINE_THRESHOLD=0.2
+#TOP_K=60

-# LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
-# Ollama example
+### LLM Configuration (Use a valid host. For local services, you can use host.docker.internal)
+### Ollama example
 LLM_BINDING=ollama
 LLM_BINDING_HOST=http://host.docker.internal:11434
 LLM_MODEL=mistral-nemo:latest

-# OpenAI alike example
+### OpenAI-alike example
 # LLM_BINDING=openai
 # LLM_MODEL=deepseek-chat
 # LLM_BINDING_HOST=https://api.deepseek.com
 # LLM_BINDING_API_KEY=your_api_key

-# for OpenAI LLM (LLM_BINDING_API_KEY take priority)
+### for OpenAI LLM (LLM_BINDING_API_KEY takes priority)
 # OPENAI_API_KEY=your_api_key

-# Lollms example
+### Lollms example
 # LLM_BINDING=lollms
 # LLM_BINDING_HOST=http://host.docker.internal:9600
 # LLM_MODEL=mistral-nemo:latest


-# Embedding Configuration (Use valid host. For local services, you can use host.docker.internal)
+### Embedding Configuration (Use a valid host. For local services, you can use host.docker.internal)
 # Ollama example
 EMBEDDING_BINDING=ollama
 EMBEDDING_BINDING_HOST=http://host.docker.internal:11434
 EMBEDDING_MODEL=bge-m3:latest

-# Lollms example
+### Lollms example
 # EMBEDDING_BINDING=lollms
 # EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
 # EMBEDDING_MODEL=bge-m3:latest

-# Security (empty for no key)
-LIGHTRAG_API_KEY=your-secure-api-key-here
-
-# Logging
-LOG_LEVEL=INFO
-
-# Optional SSL Configuration
-#SSL=true
-#SSL_CERTFILE=/path/to/cert.pem
-#SSL_KEYFILE=/path/to/key.pem
-
-# Optional Timeout
-#TIMEOUT=30
-
-# Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
+### Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority)
 # AZURE_OPENAI_API_VERSION=2024-08-01-preview
 # AZURE_OPENAI_DEPLOYMENT=gpt-4o
 # AZURE_OPENAI_API_KEY=myapikey
@@ -72,6 +76,57 @@ LOG_LEVEL=INFO
 # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
 # AZURE_EMBEDDING_API_VERSION=2023-05-15

+### Data storage selection
+# LIGHTRAG_KV_STORAGE=PGKVStorage
+# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
+# LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
+# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
+
-# Ollama Emulating Model Tag
-# OLLAMA_EMULATING_MODEL_TAG=latest
+### Oracle Database Configuration
+ORACLE_DSN=localhost:1521/XEPDB1
+ORACLE_USER=your_username
+ORACLE_PASSWORD='your_password'
+ORACLE_CONFIG_DIR=/path/to/oracle/config
+#ORACLE_WALLET_LOCATION=/path/to/wallet # optional
+#ORACLE_WALLET_PASSWORD='your_password' # optional
+#ORACLE_WORKSPACE=default # separates data from different LightRAG instances (deprecated, use NAMESPACE_PREFIX in future)
+
+### TiDB Configuration
+TIDB_HOST=localhost
+TIDB_PORT=4000
+TIDB_USER=your_username
+TIDB_PASSWORD='your_password'
+TIDB_DATABASE=your_database
+#TIDB_WORKSPACE=default # separates data from different LightRAG instances (deprecated, use NAMESPACE_PREFIX in future)
+
+### PostgreSQL Configuration
+POSTGRES_HOST=localhost
+POSTGRES_PORT=5432
+POSTGRES_USER=your_username
+POSTGRES_PASSWORD='your_password'
+POSTGRES_DATABASE=your_database
+#POSTGRES_WORKSPACE=default # separates data from different LightRAG instances (deprecated, use NAMESPACE_PREFIX in future)
+
+### Independent AGE Configuration (not for AGE embedded in PostgreSQL)
+AGE_POSTGRES_DB=
+AGE_POSTGRES_USER=
+AGE_POSTGRES_PASSWORD=
+AGE_POSTGRES_HOST=
+# AGE_POSTGRES_PORT=8529
+
+# AGE Graph Name (applies to PostgreSQL and independent AGE)
+# AGE_GRAPH_NAME=lightrag # deprecated, use NAMESPACE_PREFIX instead
+
+### Neo4j Configuration
+NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
+NEO4J_USERNAME=neo4j
+NEO4J_PASSWORD='your_password'
+
+### MongoDB Configuration
+MONGODB_URI=mongodb://root:root@localhost:27017/
+MONGODB_DATABASE=LightRAG
+MONGODB_GRAPH=false # deprecated (kept for backward compatibility)
+
+### Qdrant
+QDRANT_URL=http://localhost:16333
+QDRANT_API_KEY=your-api-key # optional
@@ -13,3 +13,28 @@ uri=redis://localhost:6379/1

 [qdrant]
 uri = http://localhost:16333
+
+[oracle]
+dsn = localhost:1521/XEPDB1
+user = your_username
+password = your_password
+config_dir = /path/to/oracle/config
+wallet_location = /path/to/wallet # optional
+wallet_password = your_wallet_password # optional
+workspace = default # optional, defaults to "default"
+
+[tidb]
+host = localhost
+port = 4000
+user = your_username
+password = your_password
+database = your_database
+workspace = default # optional, defaults to "default"
+
+[postgres]
+host = localhost
+port = 5432
+user = your_username
+password = your_password
+database = your_database
+workspace = default # optional, defaults to "default"
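Both sample files feed one lookup chain: an environment variable always wins over the matching config.ini option, and the config.ini option wins over the built-in default. A minimal sketch of that chain (section and option names match the samples above; the `_get_postgres_config` helper added later in this diff follows exactly this shape):

```python
import configparser
import os

config = configparser.ConfigParser()
config.read("config.ini")

# Resolution order: environment variable, then config.ini, then the default.
postgres_host = os.environ.get(
    "POSTGRES_HOST", config.get("postgres", "host", fallback="localhost")
)
```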
@@ -103,66 +103,23 @@ After starting the lightrag-server, you can add an Ollama-type connection in the

 LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.

-For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD remains at its default value of 0.2 in LightRAG, many irrelevant entities and relations would be retrieved and sent to the LLM.
+The default `TOP_K` is `60`, and the default `COSINE_THRESHOLD` is `0.2`.

 ### Environment Variables

-You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of available environment variables:
+You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. A sample file `.env.example` is provided for your convenience.

-```env
-# Server Configuration
-HOST=0.0.0.0
-PORT=9621
-
-# Directory Configuration
-WORKING_DIR=/app/data/rag_storage
-INPUT_DIR=/app/data/inputs
-
-# RAG Configuration
-MAX_ASYNC=4
-MAX_TOKENS=32768
-EMBEDDING_DIM=1024
-MAX_EMBED_TOKENS=8192
-#HISTORY_TURNS=3
-#CHUNK_SIZE=1200
-#CHUNK_OVERLAP_SIZE=100
-#COSINE_THRESHOLD=0.4
-#TOP_K=50
-
-# LLM Configuration
-LLM_BINDING=ollama
-LLM_BINDING_HOST=http://localhost:11434
-LLM_MODEL=mistral-nemo:latest
-
-# must be set if using OpenAI LLM (LLM_MODEL must be set or set by command line parms)
-OPENAI_API_KEY=you_api_key
-
-# Embedding Configuration
-EMBEDDING_BINDING=ollama
-EMBEDDING_BINDING_HOST=http://localhost:11434
-EMBEDDING_MODEL=bge-m3:latest
-
-# Security
-#LIGHTRAG_API_KEY=you-api-key-for-accessing-LightRAG
-
-# Logging
-LOG_LEVEL=INFO
-
-# Optional SSL Configuration
-#SSL=true
-#SSL_CERTFILE=/path/to/cert.pem
-#SSL_KEYFILE=/path/to/key.pem
-
-# Optional Timeout
-#TIMEOUT=30
-```
+### Config.ini
+
+Data storage configuration can also be set in config.ini. A sample file `config.ini.example` is provided for your convenience.

 ### Configuration Priority

 The configuration values are loaded in the following order (highest priority first):
 1. Command-line arguments
 2. Environment variables
-3. Default values
+3. Config.ini
+4. Default values

 For example:
 ```bash
@@ -173,7 +130,69 @@ python lightrag.py --port 8080
 PORT=7000 python lightrag.py
 ```

-#### LightRag Server Options
+> Best practice: set your database configuration in config.ini while testing, and use .env in production.
+
+### Storage Types Supported
+
+LightRAG uses four types of storage, each for a different purpose:
+
+* KV_STORAGE: LLM response cache, text chunks, document information
+* VECTOR_STORAGE: entity vectors, relation vectors, chunk vectors
+* GRAPH_STORAGE: entity-relation graph
+* DOC_STATUS_STORAGE: document indexing status
+
+Each storage type has several implementations:
+
+* KV_STORAGE supported implementations:
+
+```
+JsonKVStorage      JsonFile (default)
+MongoKVStorage     MongoDB
+RedisKVStorage     Redis
+TiDBKVStorage      TiDB
+PGKVStorage        Postgres
+OracleKVStorage    Oracle
+```
+
+* GRAPH_STORAGE supported implementations:
+
+```
+NetworkXStorage    NetworkX (default)
+Neo4JStorage       Neo4J
+MongoGraphStorage  MongoDB
+TiDBGraphStorage   TiDB
+AGEStorage         AGE
+GremlinStorage     Gremlin
+PGGraphStorage     Postgres
+OracleGraphStorage Oracle
+```
+
+* VECTOR_STORAGE supported implementations:
+
+```
+NanoVectorDBStorage    NanoVector (default)
+MilvusVectorDBStorage  Milvus
+ChromaVectorDBStorage  Chroma
+TiDBVectorDBStorage    TiDB
+PGVectorStorage        Postgres
+FaissVectorDBStorage   Faiss
+QdrantVectorDBStorage  Qdrant
+OracleVectorDBStorage  Oracle
+```
+
+* DOC_STATUS_STORAGE supported implementations:
+
+```
+JsonDocStatusStorage   JsonFile (default)
+PGDocStatusStorage     Postgres
+MongoDocStatusStorage  MongoDB
+```
+
+### How to Select a Storage Implementation
+
+You can select a storage implementation with environment variables or command-line arguments; the constructor sketch below shows the equivalent selection in code. You cannot change the storage implementation after documents have been added to LightRAG, and data migration from one implementation to another is not supported yet. For further information, please read the sample .env and config.ini files.
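The implementation names above are exactly what the `LightRAG` constructor accepts, so the same selection can be made when embedding LightRAG directly. A minimal sketch, with keyword names taken from the server hunks later in this diff (LLM and embedding arguments omitted):

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",
    kv_storage="PGKVStorage",
    vector_storage="PGVectorStorage",
    graph_storage="PGGraphStorage",
    doc_status_storage="PGDocStatusStorage",
    # Required by the vector storages after this PR (see the kg hunks below).
    vector_db_storage_cls_kwargs={"cosine_better_than_threshold": 0.2},
)
```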
+
+### LightRAG API Server Command-Line Options
+
 | Parameter | Default | Description |
 |-----------|---------|-------------|
@@ -200,6 +219,10 @@ PORT=7000 python lightrag.py
 | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
 | --top-k | 60 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
 | --cosine-threshold | 0.2 | The cosine threshold for node and relation retrieval; works with top-k to control the retrieval of nodes and relations. |
+| --kv-storage | JsonKVStorage | Implementation of KV_STORAGE |
+| --graph-storage | NetworkXStorage | Implementation of GRAPH_STORAGE |
+| --vector-storage | NanoVectorDBStorage | Implementation of VECTOR_STORAGE |
+| --doc-status-storage | JsonDocStatusStorage | Implementation of DOC_STATUS_STORAGE |

 ### Example Usage
@@ -343,6 +366,14 @@ curl -X POST "http://localhost:9621/documents/scan" --max-time 1800

 > Adjust max-time according to the estimated indexing time for all new files.

+#### DELETE /documents
+
+Clear all documents from the RAG system.
+
+```bash
+curl -X DELETE "http://localhost:9621/documents"
+```
+
 ### Ollama Emulation Endpoints

 #### GET /api/version
@@ -372,14 +403,6 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json"

 > For more information about the Ollama API, please visit: [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)

-#### DELETE /documents
-
-Clear all documents from the RAG system.
-
-```bash
-curl -X DELETE "http://localhost:9621/documents"
-```
-
 ### Utility Endpoints

 #### GET /health
@@ -1 +1 @@
-__api_version__ = "1.0.4"
+__api_version__ = "1.0.5"
@@ -26,7 +26,6 @@ import shutil
 import aiofiles
 from ascii_colors import trace_exception, ASCIIColors
 import sys
-import configparser
 from fastapi import Depends, Security
 from fastapi.security import APIKeyHeader
 from fastapi.middleware.cors import CORSMiddleware
@@ -34,25 +33,47 @@ from contextlib import asynccontextmanager
 from starlette.status import HTTP_403_FORBIDDEN
 import pipmaster as pm
 from dotenv import load_dotenv
+import configparser
+from lightrag.utils import logger
 from .ollama_api import (
     OllamaAPI,
 )
 from .ollama_api import ollama_server_infos
+from ..kg.postgres_impl import (
+    PostgreSQLDB,
+    PGKVStorage,
+    PGVectorStorage,
+    PGGraphStorage,
+    PGDocStatusStorage,
+)
+from ..kg.oracle_impl import (
+    OracleDB,
+    OracleKVStorage,
+    OracleVectorDBStorage,
+    OracleGraphStorage,
+)
+from ..kg.tidb_impl import (
+    TiDB,
+    TiDBKVStorage,
+    TiDBVectorDBStorage,
+    TiDBGraphStorage,
+)

 # Load environment variables
 load_dotenv(override=True)

+# Initialize config parser
+config = configparser.ConfigParser()
+config.read("config.ini")
+
-class RAGStorageConfig:
+class DefaultRAGStorageConfig:
     KV_STORAGE = "JsonKVStorage"
-    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
-    GRAPH_STORAGE = "NetworkXStorage"
     VECTOR_STORAGE = "NanoVectorDBStorage"
+    GRAPH_STORAGE = "NetworkXStorage"
+    DOC_STATUS_STORAGE = "JsonDocStatusStorage"


-# Initialize rag storage config
-rag_storage_config = RAGStorageConfig()

 # Global progress tracker
 scan_progress: Dict = {
     "is_scanning": False,
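`DefaultRAGStorageConfig` now only carries fallback names; the live selection happens in `parse_args` further down. Reduced to one line per storage type, the lookup is:

```python
import os

# e.g. "PGKVStorage" when LIGHTRAG_KV_STORAGE is set, "JsonKVStorage" otherwise
kv_storage = os.environ.get("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE)
```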
@@ -80,61 +101,6 @@ def estimate_tokens(text: str) -> int:
     return int(tokens)


-# read config.ini
-config = configparser.ConfigParser()
-config.read("config.ini", "utf-8")
-# Redis config
-redis_uri = config.get("redis", "uri", fallback=None)
-if redis_uri:
-    os.environ["REDIS_URI"] = redis_uri
-    rag_storage_config.KV_STORAGE = "RedisKVStorage"
-    rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
-
-# Neo4j config
-neo4j_uri = config.get("neo4j", "uri", fallback=None)
-neo4j_username = config.get("neo4j", "username", fallback=None)
-neo4j_password = config.get("neo4j", "password", fallback=None)
-if neo4j_uri:
-    os.environ["NEO4J_URI"] = neo4j_uri
-    os.environ["NEO4J_USERNAME"] = neo4j_username
-    os.environ["NEO4J_PASSWORD"] = neo4j_password
-    rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
-
-# Milvus config
-milvus_uri = config.get("milvus", "uri", fallback=None)
-milvus_user = config.get("milvus", "user", fallback=None)
-milvus_password = config.get("milvus", "password", fallback=None)
-milvus_db_name = config.get("milvus", "db_name", fallback=None)
-if milvus_uri:
-    os.environ["MILVUS_URI"] = milvus_uri
-    os.environ["MILVUS_USER"] = milvus_user
-    os.environ["MILVUS_PASSWORD"] = milvus_password
-    os.environ["MILVUS_DB_NAME"] = milvus_db_name
-    rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorage"
-
-# Qdrant config
-qdrant_uri = config.get("qdrant", "uri", fallback=None)
-qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
-if qdrant_uri:
-    os.environ["QDRANT_URL"] = qdrant_uri
-    if qdrant_api_key:
-        os.environ["QDRANT_API_KEY"] = qdrant_api_key
-    rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
-
-# MongoDB config
-mongo_uri = config.get("mongodb", "uri", fallback=None)
-mongo_database = config.get("mongodb", "database", fallback="LightRAG")
-mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
-if mongo_uri:
-    os.environ["MONGO_URI"] = mongo_uri
-    os.environ["MONGO_DATABASE"] = mongo_database
-    rag_storage_config.KV_STORAGE = "MongoKVStorage"
-    rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
-    if mongo_graph:
-        rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
-
-
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
         "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -247,6 +213,16 @@ def display_splash_screen(args: argparse.Namespace) -> None:
    ASCIIColors.yellow(f"{args.top_k}")

    # System Configuration
+   ASCIIColors.magenta("\n💾 Storage Configuration:")
+   ASCIIColors.white("    ├─ KV Storage: ", end="")
+   ASCIIColors.yellow(f"{args.kv_storage}")
+   ASCIIColors.white("    ├─ Vector Storage: ", end="")
+   ASCIIColors.yellow(f"{args.vector_storage}")
+   ASCIIColors.white("    ├─ Graph Storage: ", end="")
+   ASCIIColors.yellow(f"{args.graph_storage}")
+   ASCIIColors.white("    └─ Document Status Storage: ", end="")
+   ASCIIColors.yellow(f"{args.doc_status_storage}")
+
    ASCIIColors.magenta("\n🛠️ System Configuration:")
    ASCIIColors.white("    ├─ Ollama Emulating Model: ", end="")
    ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
@@ -344,6 +320,35 @@ def parse_args() -> argparse.Namespace:
         description="LightRAG FastAPI Server with separate working and input directories"
     )

+    parser.add_argument(
+        "--kv-storage",
+        default=get_env_value(
+            "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
+        ),
+        help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})",
+    )
+    parser.add_argument(
+        "--doc-status-storage",
+        default=get_env_value(
+            "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
+        ),
+        help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
+    )
+    parser.add_argument(
+        "--graph-storage",
+        default=get_env_value(
+            "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
+        ),
+        help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
+    )
+    parser.add_argument(
+        "--vector-storage",
+        default=get_env_value(
+            "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
+        ),
+        help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
+    )
+
     # Bindings configuration
     parser.add_argument(
         "--llm-binding",
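`get_env_value` is defined elsewhere in this file and does not appear in the diff; a plausible minimal equivalent is sketched below only to make the hunk above self-contained (an assumption, not the PR's actual code):

```python
import os

def get_env_value(env_key: str, default, value_type=str):
    """Read env_key from the environment, falling back to default;
    cast non-empty values to value_type (str, int, float, ...)."""
    value = os.environ.get(env_key)
    if value is None or value == "":
        return default
    return value_type(value)
```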
@@ -528,13 +533,13 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--top-k",
         type=int,
-        default=get_env_value("TOP_K", 50, int),
-        help="Number of most similar results to return (default: from env or 50)",
+        default=get_env_value("TOP_K", 60, int),
+        help="Number of most similar results to return (default: from env or 60)",
     )
     parser.add_argument(
         "--cosine-threshold",
         type=float,
-        default=get_env_value("COSINE_THRESHOLD", 0.4, float),
-        help="Cosine similarity threshold (default: from env or 0.4)",
+        default=get_env_value("COSINE_THRESHOLD", 0.2, float),
+        help="Cosine similarity threshold (default: from env or 0.2)",
     )

@@ -667,7 +672,14 @@ def get_api_key_dependency(api_key: Optional[str]):
     return api_key_auth


+# Global configuration
+global_top_k = 60  # default value
+
+
 def create_app(args):
+    global global_top_k
+    global_top_k = args.top_k  # save top_k from args
+
     # Verify that bindings are correctly setup
     if args.llm_binding not in [
         "lollms",
@@ -713,25 +725,104 @@ def create_app(args):
     @asynccontextmanager
     async def lifespan(app: FastAPI):
         """Lifespan context manager for startup and shutdown events"""
-        # Startup logic
-        if args.auto_scan_at_startup:
-            try:
-                new_files = doc_manager.scan_directory_for_new_files()
-                for file_path in new_files:
-                    try:
-                        await index_file(file_path)
-                    except Exception as e:
-                        trace_exception(e)
-                        logging.error(f"Error indexing file {file_path}: {str(e)}")
-
-                ASCIIColors.info(
-                    f"Indexed {len(new_files)} documents from {args.input_dir}"
-                )
-            except Exception as e:
-                logging.error(f"Error during startup indexing: {str(e)}")
-        yield
-        # Cleanup logic (if needed)
-        pass
+        # Initialize database connections
+        postgres_db = None
+        oracle_db = None
+        tidb_db = None
+
+        try:
+            # Check if PostgreSQL is needed
+            if any(
+                isinstance(
+                    storage_instance,
+                    (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
+                )
+                for _, storage_instance in storage_instances
+            ):
+                postgres_db = PostgreSQLDB(_get_postgres_config())
+                await postgres_db.initdb()
+                await postgres_db.check_tables()
+                for storage_name, storage_instance in storage_instances:
+                    if isinstance(
+                        storage_instance,
+                        (
+                            PGKVStorage,
+                            PGVectorStorage,
+                            PGGraphStorage,
+                            PGDocStatusStorage,
+                        ),
+                    ):
+                        storage_instance.db = postgres_db
+                        logger.info(f"Injected postgres_db to {storage_name}")
+
+            # Check if Oracle is needed
+            if any(
+                isinstance(
+                    storage_instance,
+                    (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
+                )
+                for _, storage_instance in storage_instances
+            ):
+                oracle_db = OracleDB(_get_oracle_config())
+                await oracle_db.check_tables()
+                for storage_name, storage_instance in storage_instances:
+                    if isinstance(
+                        storage_instance,
+                        (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
+                    ):
+                        storage_instance.db = oracle_db
+                        logger.info(f"Injected oracle_db to {storage_name}")
+
+            # Check if TiDB is needed
+            if any(
+                isinstance(
+                    storage_instance,
+                    (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
+                )
+                for _, storage_instance in storage_instances
+            ):
+                tidb_db = TiDB(_get_tidb_config())
+                await tidb_db.check_tables()
+                for storage_name, storage_instance in storage_instances:
+                    if isinstance(
+                        storage_instance,
+                        (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
+                    ):
+                        storage_instance.db = tidb_db
+                        logger.info(f"Injected tidb_db to {storage_name}")
+
+            # Auto scan documents if enabled
+            if args.auto_scan_at_startup:
+                try:
+                    new_files = doc_manager.scan_directory_for_new_files()
+                    for file_path in new_files:
+                        try:
+                            await index_file(file_path)
+                        except Exception as e:
+                            trace_exception(e)
+                            logging.error(f"Error indexing file {file_path}: {str(e)}")
+
+                    ASCIIColors.info(
+                        f"Indexed {len(new_files)} documents from {args.input_dir}"
+                    )
+                except Exception as e:
+                    logging.error(f"Error during startup indexing: {str(e)}")
+
+            yield
+
+        finally:
+            # Cleanup database connections
+            if postgres_db and hasattr(postgres_db, "pool"):
+                await postgres_db.pool.close()
+                logger.info("Closed PostgreSQL connection pool")
+
+            if oracle_db and hasattr(oracle_db, "pool"):
+                await oracle_db.pool.close()
+                logger.info("Closed Oracle connection pool")
+
+            if tidb_db and hasattr(tidb_db, "pool"):
+                await tidb_db.pool.close()
+                logger.info("Closed TiDB connection pool")

     # Initialize FastAPI
     app = FastAPI(
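The three backend blocks above are structurally identical: detect whether any configured storage belongs to a backend, open one shared connection, inject it into every matching instance, and close the pool on shutdown. A hypothetical helper (not part of the PR) that captures the shared shape:

```python
async def inject_backend(db, storage_types, storage_instances):
    """Share one database object across all storages of the given types."""
    await db.check_tables()
    for name, instance in storage_instances:
        if isinstance(instance, storage_types):
            instance.db = db
            logger.info(f"Injected shared db into {name}")
    return db
```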
@@ -754,6 +845,92 @@ def create_app(args):
         allow_headers=["*"],
     )

+    # Database configuration functions
+    def _get_postgres_config():
+        return {
+            "host": os.environ.get(
+                "POSTGRES_HOST",
+                config.get("postgres", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
+            ),
+            "user": os.environ.get(
+                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
+            ),
+            "password": os.environ.get(
+                "POSTGRES_PASSWORD",
+                config.get("postgres", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "POSTGRES_DATABASE",
+                config.get("postgres", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "POSTGRES_WORKSPACE",
+                config.get("postgres", "workspace", fallback="default"),
+            ),
+        }
+
+    def _get_oracle_config():
+        return {
+            "user": os.environ.get(
+                "ORACLE_USER",
+                config.get("oracle", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "ORACLE_PASSWORD",
+                config.get("oracle", "password", fallback=None),
+            ),
+            "dsn": os.environ.get(
+                "ORACLE_DSN",
+                config.get("oracle", "dsn", fallback=None),
+            ),
+            "config_dir": os.environ.get(
+                "ORACLE_CONFIG_DIR",
+                config.get("oracle", "config_dir", fallback=None),
+            ),
+            "wallet_location": os.environ.get(
+                "ORACLE_WALLET_LOCATION",
+                config.get("oracle", "wallet_location", fallback=None),
+            ),
+            "wallet_password": os.environ.get(
+                "ORACLE_WALLET_PASSWORD",
+                config.get("oracle", "wallet_password", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "ORACLE_WORKSPACE",
+                config.get("oracle", "workspace", fallback="default"),
+            ),
+        }
+
+    def _get_tidb_config():
+        return {
+            "host": os.environ.get(
+                "TIDB_HOST",
+                config.get("tidb", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
+            ),
+            "user": os.environ.get(
+                "TIDB_USER",
+                config.get("tidb", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "TIDB_PASSWORD",
+                config.get("tidb", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "TIDB_DATABASE",
+                config.get("tidb", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "TIDB_WORKSPACE",
+                config.get("tidb", "workspace", fallback="default"),
+            ),
+        }
+
     # Create the optional API key dependency
     optional_api_key = get_api_key_dependency(api_key)
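Because each helper returns a plain dict, the precedence rule is easy to check in isolation (treating the helper as if importable; in the code above it is nested inside `create_app`). For example, assuming only `POSTGRES_HOST` is exported and config.ini has no `[postgres]` section:

```python
import os

os.environ["POSTGRES_HOST"] = "db.internal"
cfg = _get_postgres_config()
assert cfg["host"] == "db.internal"   # environment variable wins
assert cfg["workspace"] == "default"  # built-in fallback applies
```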
@@ -872,10 +1049,10 @@ def create_app(args):
             if args.llm_binding == "lollms" or args.llm_binding == "ollama"
             else {},
             embedding_func=embedding_func,
-            kv_storage=rag_storage_config.KV_STORAGE,
-            graph_storage=rag_storage_config.GRAPH_STORAGE,
-            vector_storage=rag_storage_config.VECTOR_STORAGE,
-            doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
+            kv_storage=args.kv_storage,
+            graph_storage=args.graph_storage,
+            vector_storage=args.vector_storage,
+            doc_status_storage=args.doc_status_storage,
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
@@ -903,10 +1080,10 @@ def create_app(args):
             llm_model_max_async=args.max_async,
             llm_model_max_token_size=args.max_tokens,
             embedding_func=embedding_func,
-            kv_storage=rag_storage_config.KV_STORAGE,
-            graph_storage=rag_storage_config.GRAPH_STORAGE,
-            vector_storage=rag_storage_config.VECTOR_STORAGE,
-            doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
+            kv_storage=args.kv_storage,
+            graph_storage=args.graph_storage,
+            vector_storage=args.vector_storage,
+            doc_status_storage=args.doc_status_storage,
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
@@ -920,6 +1097,18 @@ def create_app(args):
         namespace_prefix=args.namespace_prefix,
     )

+    # Collect all storage instances
+    storage_instances = [
+        ("full_docs", rag.full_docs),
+        ("text_chunks", rag.text_chunks),
+        ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
+        ("entities_vdb", rag.entities_vdb),
+        ("relationships_vdb", rag.relationships_vdb),
+        ("chunks_vdb", rag.chunks_vdb),
+        ("doc_status", rag.doc_status),
+        ("llm_response_cache", rag.llm_response_cache),
+    ]
+
     async def index_file(file_path: Union[str, Path]) -> None:
         """Index all files inside the folder with support for multiple file formats
@@ -1100,7 +1289,7 @@ def create_app(args):
                     mode=request.mode,
                     stream=request.stream,
                     only_need_context=request.only_need_context,
-                    top_k=args.top_k,
+                    top_k=global_top_k,
                 ),
             )
@@ -1142,7 +1331,7 @@ def create_app(args):
                     mode=request.mode,
                     stream=True,
                     only_need_context=request.only_need_context,
-                    top_k=args.top_k,
+                    top_k=global_top_k,
                 ),
             )
@@ -1432,7 +1621,7 @@ def create_app(args):
         return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)

     # Add Ollama API routes
-    ollama_api = OllamaAPI(rag)
+    ollama_api = OllamaAPI(rag, top_k=args.top_k)
     app.include_router(ollama_api.router, prefix="/api")

     @app.get("/documents", dependencies=[Depends(optional_api_key)])
@@ -1460,10 +1649,10 @@ def create_app(args):
             "embedding_binding_host": args.embedding_binding_host,
             "embedding_model": args.embedding_model,
             "max_tokens": args.max_tokens,
-            "kv_storage": rag_storage_config.KV_STORAGE,
-            "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
-            "graph_storage": rag_storage_config.GRAPH_STORAGE,
-            "vector_storage": rag_storage_config.VECTOR_STORAGE,
+            "kv_storage": args.kv_storage,
+            "doc_status_storage": args.doc_status_storage,
+            "graph_storage": args.graph_storage,
+            "vector_storage": args.vector_storage,
         },
     }
@@ -148,9 +148,10 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:

 class OllamaAPI:
-    def __init__(self, rag: LightRAG):
+    def __init__(self, rag: LightRAG, top_k: int = 60):
         self.rag = rag
         self.ollama_server_infos = ollama_server_infos
+        self.top_k = top_k
         self.router = APIRouter()
         self.setup_routes()
@@ -381,7 +382,7 @@ class OllamaAPI:
                 "stream": request.stream,
                 "only_need_context": False,
                 "conversation_history": conversation_history,
-                "top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
+                "top_k": self.top_k,
             }

             if (
@@ -75,8 +75,8 @@ class AGEStorage(BaseGraphStorage):
             .replace("'", "\\'")
         )
         HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'")
-        PORT = int(os.environ["AGE_POSTGRES_PORT"])
-        self.graph_name = os.environ["AGE_GRAPH_NAME"]
+        PORT = os.environ.get("AGE_POSTGRES_PORT", "8529")
+        self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")

         connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
@@ -1,4 +1,3 @@
-import os
 import asyncio
 from dataclasses import dataclass
 from typing import Union
@@ -13,15 +12,17 @@ from lightrag.utils import logger
 class ChromaVectorDBStorage(BaseVectorStorage):
     """ChromaDB vector storage implementation."""

-    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+    cosine_better_than_threshold: float = None

     def __post_init__(self):
         try:
-            # Use global config value if specified, otherwise use default
             config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-            self.cosine_better_than_threshold = config.get(
-                "cosine_better_than_threshold", self.cosine_better_than_threshold
-            )
+            cosine_threshold = config.get("cosine_better_than_threshold")
+            if cosine_threshold is None:
+                raise ValueError(
+                    "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+                )
+            self.cosine_better_than_threshold = cosine_threshold

             user_collection_settings = config.get("collection_settings", {})
             # Default HNSW index settings for ChromaDB
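The same default-removal pattern is applied to every vector storage in this PR: the threshold no longer silently falls back to `COSINE_THRESHOLD`/0.2 but must arrive through `vector_db_storage_cls_kwargs`. A minimal caller-side sketch (other constructor arguments omitted):

```python
rag = LightRAG(
    working_dir="./rag_storage",
    vector_storage="ChromaVectorDBStorage",
    # Omitting this now raises ValueError when the storage initializes.
    vector_db_storage_cls_kwargs={"cosine_better_than_threshold": 0.2},
)
```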
@@ -23,14 +23,17 @@ class FaissVectorDBStorage(BaseVectorStorage):
     Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
     """

-    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+    cosine_better_than_threshold: float = None

     def __post_init__(self):
         # Grab config values if available
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        self.cosine_better_than_threshold = config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
+        cosine_threshold = config.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold

         # Where to save index file if you want persistent storage
         self._faiss_index_file = os.path.join(
@@ -47,7 +47,9 @@ class GremlinStorage(BaseGraphStorage):

         # All vertices will have graph={GRAPH} property, so that we can
         # have several logical graphs for one source
-        GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
+        GRAPH = GremlinStorage._to_value_map(
+            os.environ.get("GREMLIN_GRAPH", "LightRAG")
+        )

         self.graph_name = GRAPH
@@ -5,16 +5,22 @@ from dataclasses import dataclass
 import numpy as np
 from lightrag.utils import logger
 from ..base import BaseVectorStorage

 import pipmaster as pm
+import configparser

 if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")
 from pymilvus import MilvusClient


+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+
 @dataclass
 class MilvusVectorDBStorage(BaseVectorStorage):
+    cosine_better_than_threshold: float = None
+
     @staticmethod
     def create_collection_if_not_exist(
         client: MilvusClient, collection_name: str, **kwargs
@@ -26,15 +32,37 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         )

     def __post_init__(self):
+        # Use a distinct name for the kwargs dict so it does not shadow the
+        # module-level configparser "config" used for the fallbacks below.
+        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        cosine_threshold = kwargs.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold
+
         self._client = MilvusClient(
             uri=os.environ.get(
                 "MILVUS_URI",
-                os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
+                config.get(
+                    "milvus",
+                    "uri",
+                    fallback=os.path.join(
+                        self.global_config["working_dir"], "milvus_lite.db"
+                    ),
+                ),
             ),
-            user=os.environ.get("MILVUS_USER", ""),
-            password=os.environ.get("MILVUS_PASSWORD", ""),
-            token=os.environ.get("MILVUS_TOKEN", ""),
-            db_name=os.environ.get("MILVUS_DB_NAME", ""),
+            user=os.environ.get(
+                "MILVUS_USER", config.get("milvus", "user", fallback=None)
+            ),
+            password=os.environ.get(
+                "MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)
+            ),
+            token=os.environ.get(
+                "MILVUS_TOKEN", config.get("milvus", "token", fallback=None)
+            ),
+            db_name=os.environ.get(
+                "MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)
+            ),
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
         MilvusVectorDBStorage.create_collection_if_not_exist(
@@ -85,7 +113,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             data=embedding,
             limit=top_k,
             output_fields=list(self.meta_fields),
-            search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
+            search_params={
+                "metric_type": "COSINE",
+                "params": {"radius": self.cosine_better_than_threshold},
+            },
         )
         print(results)
         return [
@@ -1,8 +1,8 @@
 import os
 from dataclasses import dataclass

 import numpy as np
 import pipmaster as pm
+import configparser
 from tqdm.asyncio import tqdm as tqdm_async

 if not pm.is_installed("pymongo"):
@@ -12,7 +12,6 @@ if not pm.is_installed("motor"):
     pm.install("motor")

 from typing import Any, List, Tuple, Union
-
 from motor.motor_asyncio import AsyncIOMotorClient
 from pymongo import MongoClient
@@ -27,13 +26,27 @@ from ..namespace import NameSpace, is_namespace
 from ..utils import logger


+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+
 @dataclass
 class MongoKVStorage(BaseKVStorage):
     def __post_init__(self):
         client = MongoClient(
-            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
+            os.environ.get(
+                "MONGO_URI",
+                config.get(
+                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+                ),
+            )
         )
-        database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
+        database = client.get_database(
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        )
         self._data = database.get_collection(self.namespace)
         logger.info(f"Use MongoDB as KV {self.namespace}")
@@ -173,10 +186,25 @@ class MongoGraphStorage(BaseGraphStorage):
             embedding_func=embedding_func,
         )
         self.client = AsyncIOMotorClient(
-            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
+            os.environ.get(
+                "MONGO_URI",
+                config.get(
+                    "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
+                ),
+            )
         )
-        self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
-        self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]
+        self.db = self.client[
+            os.environ.get(
+                "MONGO_DATABASE",
+                config.get("mongodb", "database", fallback="LightRAG"),
+            )
+        ]
+        self.collection = self.db[
+            os.environ.get(
+                "MONGO_KG_COLLECTION",
+                config.get("mongodb", "kg_collection", fallback="MDB_KG"),
+            )
+        ]

         #
         # -------------------------------------------------------------------------
@@ -73,16 +73,19 @@ from lightrag.base import (

 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+    cosine_better_than_threshold: float = None

     def __post_init__(self):
         # Initialize lock only for file operations
         self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        self.cosine_better_than_threshold = config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
+        cosine_threshold = config.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold

         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -139,9 +142,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
     async def query(self, query: str, top_k=5):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
-        logger.info(
-            f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
-        )
         results = self._client.query(
             query=embedding,
             top_k=top_k,
@@ -5,6 +5,7 @@ import re
 from dataclasses import dataclass
 from typing import Any, Union, Tuple, List, Dict
 import pipmaster as pm
+import configparser

 if not pm.is_installed("neo4j"):
     pm.install("neo4j")
@@ -28,6 +29,10 @@ from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge


+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
     @staticmethod
@@ -42,13 +47,22 @@ class Neo4JStorage(BaseGraphStorage):
         )
         self._driver = None
         self._driver_lock = asyncio.Lock()
-        URI = os.environ["NEO4J_URI"]
-        USERNAME = os.environ["NEO4J_USERNAME"]
-        PASSWORD = os.environ["NEO4J_PASSWORD"]
-        MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
+        URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
+        USERNAME = os.environ.get(
+            "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
+        )
+        PASSWORD = os.environ.get(
+            "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
+        )
+        MAX_CONNECTION_POOL_SIZE = os.environ.get(
+            "NEO4J_MAX_CONNECTION_POOL_SIZE",
+            config.get("neo4j", "connection_pool_size", fallback=800),
+        )
         DATABASE = os.environ.get(
             "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
         )

         self._driver: AsyncDriver = AsyncGraphDatabase.driver(
             URI, auth=(USERNAME, PASSWORD)
         )
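Note: connection settings now fall back from environment variables to a `config.ini`. A sketch of the resolution order, with the `[neo4j]` section and option names taken from this diff and placeholder values:

```python
# Sketch: the env var wins; otherwise the [neo4j] option from config.ini; otherwise the fallback.
import configparser
import os

config = configparser.ConfigParser()
config.read_string("""
[neo4j]
uri = neo4j://localhost:7687
username = neo4j
password = your-password
connection_pool_size = 800
""")

uri = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
pool_size = int(os.environ.get(
    "NEO4J_MAX_CONNECTION_POOL_SIZE",
    config.get("neo4j", "connection_pool_size", fallback=800),
))
print(uri, pool_size)
```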
@@ -1,6 +1,5 @@
 import array
 import asyncio
-import os

 # import html
 # import os

@@ -172,8 +171,8 @@ class OracleDB:

 @dataclass
 class OracleKVStorage(BaseKVStorage):
-    # should pass db object to self.db
-    db: OracleDB = None
+    # db instance must be injected before use
+    # db: OracleDB
     meta_fields = None

     def __post_init__(self):
@@ -318,16 +317,18 @@ class OracleKVStorage(BaseKVStorage):

 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
-    # should pass db object to self.db
-    db: OracleDB = None
-    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+    # db instance must be injected before use
+    # db: OracleDB
+    cosine_better_than_threshold: float = None

     def __post_init__(self):
-        # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        self.cosine_better_than_threshold = config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
+        cosine_threshold = config.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold

     async def upsert(self, data: dict[str, dict]):
         """Insert data into the vector database"""
@@ -361,7 +362,8 @@ class OracleVectorDBStorage(BaseVectorStorage):

 @dataclass
 class OracleGraphStorage(BaseGraphStorage):
-    """Oracle-based graph storage module"""
+    # db instance must be injected before use
+    # db: OracleDB

     def __post_init__(self):
         """Load the graph from a graphml file"""
@@ -177,7 +177,8 @@ class PostgreSQLDB:

 @dataclass
 class PGKVStorage(BaseKVStorage):
-    db: PostgreSQLDB = None
+    # db instance must be injected before use
+    # db: PostgreSQLDB

     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -296,16 +297,19 @@ class PGKVStorage(BaseKVStorage):

 @dataclass
 class PGVectorStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
-    db: PostgreSQLDB = None
+    # db instance must be injected before use
+    # db: PostgreSQLDB
+    cosine_better_than_threshold: float = None

     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
-        # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        self.cosine_better_than_threshold = config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
+        cosine_threshold = config.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold

     def _upsert_chunks(self, item: dict):
         try:
@@ -416,20 +420,14 @@ class PGVectorStorage(BaseVectorStorage):

 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
-    """PostgreSQL implementation of document status storage"""
-
-    db: PostgreSQLDB = None
-
-    def __post_init__(self):
-        pass
+    # db instance must be injected before use
+    # db: PostgreSQLDB

     async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         keys = ",".join([f"'{_id}'" for _id in data])
-        sql = (
-            f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})"
-        )
-        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
+        sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
+        result = await self.db.query(sql, multirows=True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
             return set(data)
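Note: `filter_keys` returns the subset of candidate ids that are not yet stored. A sketch of the set arithmetic over the row shape noted in the comment above (values are illustrative):

```python
# Sketch of filter_keys' logic; rows mirror the [{'id': ...}] shape the query returns.
rows = [{"id": "doc-1"}, {"id": "doc-3"}]   # stand-in for db.query(..., multirows=True)
data = {"doc-1", "doc-2", "doc-3"}
existing_ids = {row["id"] for row in rows}
print(data - existing_ids)  # {'doc-2'} -> keys that still need processing
```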
@@ -585,19 +583,15 @@ class PGGraphQueryException(Exception):

 @dataclass
 class PGGraphStorage(BaseGraphStorage):
-    db: PostgreSQLDB = None
+    # db instance must be injected before use
+    # db: PostgreSQLDB

     @staticmethod
     def load_nx_graph(file_name):
         print("no preloading of graph with AGE in production")

-    def __init__(self, namespace, global_config, embedding_func):
-        super().__init__(
-            namespace=namespace,
-            global_config=global_config,
-            embedding_func=embedding_func,
-        )
-        self.graph_name = os.environ["AGE_GRAPH_NAME"]
+    def __post_init__(self):
+        self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }
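Note: the explicit `__init__` is gone; the dataclass `__post_init__` now derives the graph name with a fallback chain. A sketch of the resolution order as written in this diff:

```python
# Sketch of the new graph-name resolution in PGGraphStorage.__post_init__:
# namespace first, then the AGE_GRAPH_NAME env var, then the literal "lightrag".
import os

def resolve_graph_name(namespace: str) -> str:
    return namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")

print(resolve_graph_name("chunk_entity_relation"))  # "chunk_entity_relation"
print(resolve_graph_name(""))                       # AGE_GRAPH_NAME, or "lightrag"
```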
@@ -1137,7 +1131,7 @@ TABLES = {
         "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
-                   content TEXT,
+                   content TEXT NULL,
                    content_summary varchar(255) NULL,
                    content_length int4 NULL,
                    chunks_count int4 NULL,
@@ -1,138 +0,0 @@
-import asyncio
-import sys
-import os
-import pipmaster as pm
-
-if not pm.is_installed("psycopg-pool"):
-    pm.install("psycopg-pool")
-    pm.install("psycopg[binary,pool]")
-if not pm.is_installed("asyncpg"):
-    pm.install("asyncpg")
-
-import asyncpg
-import psycopg
-from psycopg_pool import AsyncConnectionPool
-
-from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
-from ..namespace import NameSpace
-
-DB = "rag"
-USER = "rag"
-PASSWORD = "rag"
-HOST = "localhost"
-PORT = "15432"
-os.environ["AGE_GRAPH_NAME"] = "dickens"
-
-if sys.platform.startswith("win"):
-    import asyncio.windows_events
-
-    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
-
-async def get_pool():
-    return await asyncpg.create_pool(
-        f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
-        min_size=10,
-        max_size=10,
-        max_queries=5000,
-        max_inactive_connection_lifetime=300.0,
-    )
-
-
-async def main1():
-    connection_string = (
-        f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
-    )
-    pool = AsyncConnectionPool(connection_string, open=False)
-    await pool.open()
-
-    try:
-        conn = await pool.getconn(timeout=10)
-        async with conn.cursor() as curs:
-            try:
-                await curs.execute('SET search_path = ag_catalog, "$user", public')
-                await curs.execute("SELECT create_graph('dickens-2')")
-                await conn.commit()
-                print("create_graph success")
-            except (
-                psycopg.errors.InvalidSchemaName,
-                psycopg.errors.UniqueViolation,
-            ):
-                print("create_graph already exists")
-                await conn.rollback()
-    finally:
-        pass
-
-
-db = PostgreSQLDB(
-    config={
-        "host": "localhost",
-        "port": 15432,
-        "user": "rag",
-        "password": "rag",
-        "database": "r1",
-    }
-)
-
-
-async def query_with_age():
-    await db.initdb()
-    graph = PGGraphStorage(
-        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
-        global_config={},
-        embedding_func=None,
-    )
-    graph.db = db
-    res = await graph.get_node('"A CHRISTMAS CAROL"')
-    print("Node is: ", res)
-    res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
-    print("Edge is: ", res)
-    res = await graph.get_node_edges('"SCROOGE"')
-    print("Node Edges are: ", res)
-
-
-async def create_edge_with_age():
-    await db.initdb()
-    graph = PGGraphStorage(
-        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
-        global_config={},
-        embedding_func=None,
-    )
-    graph.db = db
-    await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
-    await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
-    await graph.upsert_edge(
-        '"THE CRATCHITS"',
-        '"THE GIRLS"',
-        edge_data={
-            "weight": 7.0,
-            "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
-            "keywords": '"family, collective effort"',
-            "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
-        },
-    )
-    res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
-    print("Edge is: ", res)
-
-
-async def main():
-    pool = await get_pool()
-    sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
-    # cypher = "MATCH (n:how_are_you_doing) RETURN n"
-    async with pool.acquire() as conn:
-        try:
-            await conn.execute(
-                """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
-            )
-        except asyncpg.exceptions.InvalidSchemaNameError:
-            print("create_graph already exists")
-        # stmt = await conn.prepare(sql)
-        row = await conn.fetch(sql)
-        print("row is: ", row)
-
-        row = await conn.fetchrow("select '100'::int + 200 as result")
-        print(row)  # <Record result=300>
-
-
-if __name__ == "__main__":
-    asyncio.run(query_with_age())
@@ -5,11 +5,10 @@ from dataclasses import dataclass
 import numpy as np
 import hashlib
 import uuid

 from ..utils import logger
 from ..base import BaseVectorStorage

 import pipmaster as pm
+import configparser

 if not pm.is_installed("qdrant_client"):
     pm.install("qdrant_client")

@@ -17,6 +16,10 @@ if not pm.is_installed("qdrant_client"):
 from qdrant_client import QdrantClient, models


+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+
 def compute_mdhash_id_for_qdrant(
     content: str, prefix: str = "", style: str = "simple"
 ) -> str:
@@ -47,6 +50,8 @@ def compute_mdhash_id_for_qdrant(

 @dataclass
 class QdrantVectorDBStorage(BaseVectorStorage):
+    cosine_better_than_threshold: float = None
+
     @staticmethod
     def create_collection_if_not_exist(
         client: QdrantClient, collection_name: str, **kwargs
@@ -56,9 +61,21 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         client.create_collection(collection_name, **kwargs)

     def __post_init__(self):
+        # Use a distinct local name so the module-level configparser `config`
+        # stays reachable for the URL/API-key fallbacks below
+        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        cosine_threshold = kwargs.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold
+
         self._client = QdrantClient(
-            url=os.environ.get("QDRANT_URL"),
-            api_key=os.environ.get("QDRANT_API_KEY", None),
+            url=os.environ.get(
+                "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
+            ),
+            api_key=os.environ.get(
+                "QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)
+            ),
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
         QdrantVectorDBStorage.create_collection_if_not_exist(
@@ -122,4 +139,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
             limit=top_k,
             with_payload=True,
         )
-        return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
+        logger.debug(f"query result: {results}")
+        # Filter out hits below the cosine similarity threshold
+        filtered_results = [
+            dp for dp in results if dp.score >= self.cosine_better_than_threshold
+        ]
+        return [
+            {**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
+        ]
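Note: Qdrant returns similarity scores with each hit, so the threshold is applied after the query rather than inside it. A sketch of the filter, with plain dicts standing in for `qdrant_client` scored points:

```python
# Sketch of the post-query cosine filter added above.
cosine_better_than_threshold = 0.2
results = [
    {"id": "a", "score": 0.35, "payload": {"content": "kept"}},
    {"id": "b", "score": 0.10, "payload": {"content": "dropped"}},
]
filtered = [dp for dp in results if dp["score"] >= cosine_better_than_threshold]
print([dp["id"] for dp in filtered])  # ['a']
```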
@@ -3,6 +3,7 @@ from typing import Any, Union
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
+import configparser

 if not pm.is_installed("redis"):
     pm.install("redis")

@@ -14,10 +15,16 @@ from lightrag.base import BaseKVStorage
 import json


+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+
 @dataclass
 class RedisKVStorage(BaseKVStorage):
     def __post_init__(self):
-        redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
+        redis_url = os.environ.get(
+            "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
+        )
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")
@@ -101,7 +101,9 @@ class TiDB:

 @dataclass
 class TiDBKVStorage(BaseKVStorage):
-    # should pass db object to self.db
+    # db instance must be injected before use
+    # db: TiDB

     def __post_init__(self):
         self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]

@@ -208,18 +210,22 @@ class TiDBKVStorage(BaseKVStorage):

 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+    # db instance must be injected before use
+    # db: TiDB
+    cosine_better_than_threshold: float = None

     def __post_init__(self):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
-        # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        self.cosine_better_than_threshold = config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
+        cosine_threshold = config.get("cosine_better_than_threshold")
+        if cosine_threshold is None:
+            raise ValueError(
+                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
+            )
+        self.cosine_better_than_threshold = cosine_threshold

     async def query(self, query: str, top_k: int) -> list[dict]:
         """Search from tidb vector"""
@@ -329,6 +335,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):

 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):
+    # db instance must be injected before use
+    # db: TiDB
+
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]

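Note: across the Oracle, PostgreSQL, and TiDB backends this PR removes `db` as a dataclass field and replaces it with the "inject before use" convention. A small generic sketch of the pattern; `StorageSketch` is hypothetical:

```python
# Sketch of the "inject db before use" pattern these comments describe.
from dataclasses import dataclass

@dataclass
class StorageSketch:
    namespace: str
    # db intentionally not declared as a field; it is injected after construction

storage = StorageSketch(namespace="chunks")
storage.db = object()  # stand-in for a connected TiDB/Oracle/PostgreSQL client
print(hasattr(storage, "db"))  # True
```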
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -36,6 +37,111 @@ from .utils import (
 )
 from .types import KnowledgeGraph

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+# Storage type and implementation compatibility validation table
+STORAGE_IMPLEMENTATIONS = {
+    "KV_STORAGE": {
+        "implementations": [
+            "JsonKVStorage",
+            "MongoKVStorage",
+            "RedisKVStorage",
+            "TiDBKVStorage",
+            "PGKVStorage",
+            "OracleKVStorage",
+        ],
+        "required_methods": ["get_by_id", "upsert"],
+    },
+    "GRAPH_STORAGE": {
+        "implementations": [
+            "NetworkXStorage",
+            "Neo4JStorage",
+            "MongoGraphStorage",
+            "TiDBGraphStorage",
+            "AGEStorage",
+            "GremlinStorage",
+            "PGGraphStorage",
+            "OracleGraphStorage",
+        ],
+        "required_methods": ["upsert_node", "upsert_edge"],
+    },
+    "VECTOR_STORAGE": {
+        "implementations": [
+            "NanoVectorDBStorage",
+            "MilvusVectorDBStorge",
+            "ChromaVectorDBStorage",
+            "TiDBVectorDBStorage",
+            "PGVectorStorage",
+            "FaissVectorDBStorage",
+            "QdrantVectorDBStorage",
+            "OracleVectorDBStorage",
+        ],
+        "required_methods": ["query", "upsert"],
+    },
+    "DOC_STATUS_STORAGE": {
+        "implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"],
+        "required_methods": ["get_pending_docs"],
+    },
+}
+
+# Environment variables required by each storage implementation (no default values)
+STORAGE_ENV_REQUIREMENTS = {
+    # KV Storage Implementations
+    "JsonKVStorage": [],
+    "MongoKVStorage": [],
+    "RedisKVStorage": ["REDIS_URI"],
+    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "OracleKVStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    # Graph Storage Implementations
+    "NetworkXStorage": [],
+    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
+    "MongoGraphStorage": [],
+    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "AGEStorage": [
+        "AGE_POSTGRES_DB",
+        "AGE_POSTGRES_USER",
+        "AGE_POSTGRES_PASSWORD",
+    ],
+    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
+    "PGGraphStorage": [
+        "POSTGRES_USER",
+        "POSTGRES_PASSWORD",
+        "POSTGRES_DATABASE",
+    ],
+    "OracleGraphStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    # Vector Storage Implementations
+    "NanoVectorDBStorage": [],
+    "MilvusVectorDBStorge": [],
+    "ChromaVectorDBStorage": [],
+    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "FaissVectorDBStorage": [],
+    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
+    "OracleVectorDBStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    # Document Status Storage Implementations
+    "JsonDocStatusStorage": [],
+    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "MongoDocStatusStorage": [],
+}
+
+# Storage implementation module mapping
 STORAGES = {
     "NetworkXStorage": ".kg.networkx_impl",
     "JsonKVStorage": ".kg.json_kv_impl",
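Note: these tables drive the datastore selection feature: each `*_storage` field on `LightRAG` is validated against its compatible implementations and required environment variables before anything is initialized. A hedged usage sketch; the field names match this diff, but a working instance also needs the LLM and embedding functions configured, which are omitted here:

```python
# Hedged sketch: selecting storage backends by name when constructing LightRAG.
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",
    kv_storage="RedisKVStorage",             # needs REDIS_URI
    vector_storage="QdrantVectorDBStorage",  # needs QDRANT_URL
    graph_storage="Neo4JStorage",            # needs NEO4J_URI/USERNAME/PASSWORD
    doc_status_storage="JsonDocStatusStorage",
)
```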
@@ -140,6 +246,9 @@ class LightRAG:
     graph_storage: str = field(default="NetworkXStorage")
     """Storage backend for knowledge graphs."""

+    doc_status_storage: str = field(default="JsonDocStatusStorage")
+    """Storage type for tracking document processing statuses."""
+
     # Logging
     current_log_level = logger.level
     log_level: int = field(default=current_log_level)
@@ -236,9 +345,6 @@ class LightRAG:
             convert_response_to_json
         )

-    doc_status_storage: str = field(default="JsonDocStatusStorage")
-    """Storage type for tracking document processing statuses."""
-
     # Custom Chunking Function
     chunking_func: Callable[
         [
@@ -252,6 +358,46 @@ class LightRAG:
         list[dict[str, Any]],
     ] = chunking_by_token_size

+    def verify_storage_implementation(
+        self, storage_type: str, storage_name: str
+    ) -> None:
+        """Verify that a storage implementation is compatible with the specified storage type
+
+        Args:
+            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If the storage implementation is incompatible or missing required methods
+        """
+        if storage_type not in STORAGE_IMPLEMENTATIONS:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
+        if storage_name not in storage_info["implementations"]:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
+                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
+            )
+
+    def check_storage_env_vars(self, storage_name: str) -> None:
+        """Check that all environment variables required by a storage implementation exist
+
+        Args:
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If required environment variables are missing
+        """
+        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+        missing_vars = [var for var in required_vars if var not in os.environ]
+
+        if missing_vars:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' requires the following "
+                f"environment variables: {', '.join(missing_vars)}"
+            )
+
     def __post_init__(self):
         os.makedirs(self.log_dir, exist_ok=True)
         log_file = os.path.join(self.log_dir, "lightrag.log")
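Note: a sketch of what the two new helpers enforce, reduced to a standalone function over the tables above (the tables are trimmed here for brevity):

```python
# Standalone sketch of the validation flow; `validate` is a hypothetical condensation
# of verify_storage_implementation + check_storage_env_vars.
import os

STORAGE_IMPLEMENTATIONS = {
    "VECTOR_STORAGE": {"implementations": ["NanoVectorDBStorage", "QdrantVectorDBStorage"]}
}
STORAGE_ENV_REQUIREMENTS = {"QdrantVectorDBStorage": ["QDRANT_URL"]}

def validate(storage_type: str, storage_name: str) -> None:
    info = STORAGE_IMPLEMENTATIONS[storage_type]
    if storage_name not in info["implementations"]:
        raise ValueError(f"{storage_name} is not compatible with {storage_type}")
    missing = [v for v in STORAGE_ENV_REQUIREMENTS.get(storage_name, []) if v not in os.environ]
    if missing:
        raise ValueError(f"{storage_name} requires: {', '.join(missing)}")

validate("VECTOR_STORAGE", "NanoVectorDBStorage")    # passes: no env vars required
validate("VECTOR_STORAGE", "QdrantVectorDBStorage")  # raises unless QDRANT_URL is set
```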
@@ -263,6 +409,29 @@ class LightRAG:
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)

+        # Verify storage implementation compatibility and environment variables
+        storage_configs = [
+            ("KV_STORAGE", self.kv_storage),
+            ("VECTOR_STORAGE", self.vector_storage),
+            ("GRAPH_STORAGE", self.graph_storage),
+            ("DOC_STATUS_STORAGE", self.doc_status_storage),
+        ]
+
+        for storage_type, storage_name in storage_configs:
+            # Verify storage implementation compatibility
+            self.verify_storage_implementation(storage_type, storage_name)
+            # Check environment variables
+            self.check_storage_env_vars(storage_name)
+
+        # Ensure vector_db_storage_cls_kwargs has required fields
+        default_vector_db_kwargs = {
+            "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
+        }
+        self.vector_db_storage_cls_kwargs = {
+            **default_vector_db_kwargs,
+            **self.vector_db_storage_cls_kwargs,
+        }
+
         # show config
         global_config = asdict(self)
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
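Note: this merge is what lets the vector storages raise when the threshold is truly absent while LightRAG still supplies the `COSINE_THRESHOLD` env default. A sketch of the merge order, showing that an explicit kwarg always wins:

```python
# Sketch of the kwargs merge above: defaults first, user values second.
import os

default_vector_db_kwargs = {
    "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
}
user_kwargs = {"cosine_better_than_threshold": 0.35}
print({**default_vector_db_kwargs, **user_kwargs})  # {'cosine_better_than_threshold': 0.35}
```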
@@ -296,10 +465,8 @@ class LightRAG:
             self.graph_storage_cls, global_config=global_config
         )

-        self.json_doc_status_storage = self.key_string_value_json_storage_cls(
-            namespace=self.namespace_prefix + "json_doc_status_storage",
-            embedding_func=None,
-        )
+        # Initialize document status storage
+        self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

         self.llm_response_cache = self.key_string_value_json_storage_cls(
             namespace=make_namespace(
@@ -308,9 +475,6 @@ class LightRAG:
             embedding_func=self.embedding_func,
         )

-        ####
-        # add embedding func by walter
-        ####
         self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
@@ -329,9 +493,6 @@ class LightRAG:
             ),
             embedding_func=self.embedding_func,
         )
-        ####
-        # add embedding func by walter over
-        ####

         self.entities_vdb = self.vector_db_storage_cls(
             namespace=make_namespace(
@@ -354,6 +515,14 @@ class LightRAG:
             embedding_func=self.embedding_func,
         )

+        # Initialize document status storage
+        self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
+            namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
+            global_config=global_config,
+            embedding_func=None,
+        )
+
+        # What is this for? Is it necessary?
         if self.llm_response_cache and hasattr(
             self.llm_response_cache, "global_config"
         ):
@@ -374,14 +543,6 @@ class LightRAG:
             )
         )

-        # Initialize document status storage
-        self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
-        self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
-            namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
-            global_config=global_config,
-            embedding_func=None,
-        )
-
     async def get_graph_labels(self):
         text = await self.chunk_entity_relation_graph.get_all_labels()
         return text
@@ -399,7 +560,8 @@ class LightRAG:
         return storage_class

     def set_storage_client(self, db_client):
-        # Now only tested on Oracle Database
+        # Deprecated: set the correct value on LightRAG's *_storage fields instead
+        # Inject db to the storage implementation (only tested on Oracle Database)
         for storage in [
             self.vector_db_storage_cls,
             self.graph_storage_cls,
@@ -1055,6 +1055,9 @@ async def _get_node_data(
     query_param: QueryParam,
 ):
     # get similar entities
+    logger.info(
+        f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
+    )
     results = await entities_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return "", "", ""
@@ -1270,6 +1273,9 @@ async def _get_edge_data(
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
+    logger.info(
+        f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
+    )
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)

     if not len(results):
@@ -416,7 +416,13 @@ async def get_best_cached_response(

     if best_similarity > similarity_threshold:
         # If LLM check is enabled and all required parameters are provided
-        if use_llm_check and llm_func and original_prompt and best_prompt:
+        if (
+            use_llm_check
+            and llm_func
+            and original_prompt
+            and best_prompt
+            and best_response is not None
+        ):
             compare_prompt = PROMPTS["similarity_check"].format(
                 original_prompt=original_prompt, cached_prompt=best_prompt
             )
@@ -430,7 +436,9 @@ async def get_best_cached_response(
                 best_similarity = llm_similarity
                 if best_similarity < similarity_threshold:
                     log_data = {
-                        "event": "llm_check_cache_rejected",
+                        "event": "cache_rejected_by_llm",
+                        "type": cache_type,
+                        "mode": mode,
                         "original_question": original_prompt[:100] + "..."
                         if len(original_prompt) > 100
                         else original_prompt,

@@ -440,7 +448,8 @@ async def get_best_cached_response(
                         "similarity_score": round(best_similarity, 4),
                         "threshold": similarity_threshold,
                     }
-                    logger.info(json.dumps(log_data, ensure_ascii=False))
+                    logger.debug(json.dumps(log_data, ensure_ascii=False))
+                    logger.info(f"Cache rejected by LLM(mode:{mode} type:{cache_type})")
                     return None
         except Exception as e:  # Catch all possible exceptions
             logger.warning(f"LLM similarity check failed: {e}")
@@ -451,12 +460,13 @@ async def get_best_cached_response(
         )
         log_data = {
             "event": "cache_hit",
+            "type": cache_type,
             "mode": mode,
             "similarity": round(best_similarity, 4),
             "cache_id": best_cache_id,
             "original_prompt": prompt_display,
         }
-        logger.info(json.dumps(log_data, ensure_ascii=False))
+        logger.debug(json.dumps(log_data, ensure_ascii=False))
         return best_response
     return None

@@ -534,19 +544,24 @@ async def handle_cache(
                 cache_type=cache_type,
             )
             if best_cached_response is not None:
+                logger.info(f"Embedding cache hit(mode:{mode} type:{cache_type})")
                 return best_cached_response, None, None, None
             else:
+                # if caching keyword embeddings is enabled, return the quantized
+                # embedding so it can be saved later
+                logger.info(f"Embedding cache missed(mode:{mode} type:{cache_type})")
                 return None, quantized, min_val, max_val

-    # For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
-    # Use regular cache
+    # For default mode or if is_embedding_cache_enabled is False, use the regular cache
+    # (default mode is for extract_entities or naive query)
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
         mode_cache = await hashing_kv.get_by_id(mode) or {}
     if args_hash in mode_cache:
+        logger.info(f"Non-embedding cache hit(mode:{mode} type:{cache_type})")
         return mode_cache[args_hash]["return"], None, None, None

+    logger.info(f"Non-embedding cache missed(mode:{mode} type:{cache_type})")
     return None, None, None, None
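Note: after this change `handle_cache` keeps its 4-tuple contract of `(cached_response, quantized_embedding, min_val, max_val)`, with only one side populated per outcome. A sketch of the three cases visible in the diff; `describe` is a hypothetical helper:

```python
# Sketch of handle_cache's return contract.
def describe(result):
    response, quantized, _min_val, _max_val = result
    if response is not None:
        return "cache hit"
    if quantized is not None:
        return "cache miss, quantized embedding kept for a later save"
    return "cache miss"

print(describe(("cached answer", None, None, None)))  # cache hit
print(describe((None, b"\x01", -1.0, 1.0)))           # cache miss, quantized embedding kept ...
print(describe((None, None, None, None)))             # cache miss
```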