Merge branch 'main'
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
||||
|
||||
__version__ = "1.2.8"
|
||||
__version__ = "1.2.9"
|
||||
__author__ = "Zirui Guo"
|
||||
__url__ = "https://github.com/HKUDS/LightRAG"
|
||||
|
559
lightrag/api/README-zh.md
Normal file
559
lightrag/api/README-zh.md
Normal file
@@ -0,0 +1,559 @@
|
||||
# LightRAG 服务器和 Web 界面
|
||||
|
||||
LightRAG 服务器旨在提供 Web 界面和 API 支持。Web 界面便于文档索引、知识图谱探索和简单的 RAG 查询界面。LightRAG 服务器还提供了与 Ollama 兼容的接口,旨在将 LightRAG 模拟为 Ollama 聊天模型。这使得 AI 聊天机器人(如 Open WebUI)可以轻松访问 LightRAG。
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## 入门指南
|
||||
|
||||
### 安装
|
||||
|
||||
* 从 PyPI 安装
|
||||
|
||||
```bash
|
||||
pip install "lightrag-hku[api]"
|
||||
```
|
||||
|
||||
* 从源代码安装
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone https://github.com/HKUDS/lightrag.git
|
||||
|
||||
# 切换到仓库目录
|
||||
cd lightrag
|
||||
|
||||
# 如有必要,创建 Python 虚拟环境
|
||||
# 以可编辑模式安装并支持 API
|
||||
pip install -e ".[api]"
|
||||
```
|
||||
|
||||
### 启动 LightRAG 服务器前的准备
|
||||
|
||||
LightRAG 需要同时集成 LLM(大型语言模型)和嵌入模型以有效执行文档索引和查询操作。在首次部署 LightRAG 服务器之前,必须配置 LLM 和嵌入模型的设置。LightRAG 支持绑定到各种 LLM/嵌入后端:
|
||||
|
||||
* ollama
|
||||
* lollms
|
||||
* openai 或 openai 兼容
|
||||
* azure_openai
|
||||
|
||||
建议使用环境变量来配置 LightRAG 服务器。项目根目录中有一个名为 `env.example` 的示例环境变量文件。请将此文件复制到启动目录并重命名为 `.env`。之后,您可以在 `.env` 文件中修改与 LLM 和嵌入模型相关的参数。需要注意的是,LightRAG 服务器每次启动时都会将 `.env` 中的环境变量加载到系统环境变量中。由于 LightRAG 服务器会优先使用系统环境变量中的设置,如果您在通过命令行启动 LightRAG 服务器后修改了 `.env` 文件,则需要执行 `source .env` 使新设置生效。
|
||||
|
||||
以下是 LLM 和嵌入模型的一些常见设置示例:
|
||||
|
||||
* OpenAI LLM + Ollama 嵌入
|
||||
|
||||
```
|
||||
LLM_BINDING=openai
|
||||
LLM_MODEL=gpt-4o
|
||||
LLM_BINDING_HOST=https://api.openai.com/v1
|
||||
LLM_BINDING_API_KEY=your_api_key
|
||||
MAX_TOKENS=32768 # 发送给 LLM 的最大 token 数(小于模型上下文大小)
|
||||
|
||||
EMBEDDING_BINDING=ollama
|
||||
EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||
EMBEDDING_MODEL=bge-m3:latest
|
||||
EMBEDDING_DIM=1024
|
||||
# EMBEDDING_BINDING_API_KEY=your_api_key
|
||||
```
|
||||
|
||||
* Ollama LLM + Ollama 嵌入
|
||||
|
||||
```
|
||||
LLM_BINDING=ollama
|
||||
LLM_MODEL=mistral-nemo:latest
|
||||
LLM_BINDING_HOST=http://localhost:11434
|
||||
# LLM_BINDING_API_KEY=your_api_key
|
||||
MAX_TOKENS=8192 # 发送给 LLM 的最大 token 数(基于您的 Ollama 服务器容量)
|
||||
|
||||
EMBEDDING_BINDING=ollama
|
||||
EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||
EMBEDDING_MODEL=bge-m3:latest
|
||||
EMBEDDING_DIM=1024
|
||||
# EMBEDDING_BINDING_API_KEY=your_api_key
|
||||
```
|
||||
|
||||
### 启动 LightRAG 服务器
|
||||
|
||||
LightRAG 服务器支持两种运行模式:
|
||||
* 简单高效的 Uvicorn 模式
|
||||
|
||||
```
|
||||
lightrag-server
|
||||
```
|
||||
* 多进程 Gunicorn + Uvicorn 模式(生产模式,不支持 Windows 环境)
|
||||
|
||||
```
|
||||
lightrag-gunicorn --workers 4
|
||||
```
|
||||
`.env` 文件必须放在启动目录中。启动时,LightRAG 服务器将创建一个文档目录(默认为 `./inputs`)和一个数据目录(默认为 `./rag_storage`)。这允许您从不同目录启动多个 LightRAG 服务器实例,每个实例配置为监听不同的网络端口。
|
||||
|
||||
以下是一些常用的启动参数:
|
||||
|
||||
- `--host`:服务器监听地址(默认:0.0.0.0)
|
||||
- `--port`:服务器监听端口(默认:9621)
|
||||
- `--timeout`:LLM 请求超时时间(默认:150 秒)
|
||||
- `--log-level`:日志级别(默认:INFO)
|
||||
- --input-dir:指定要扫描文档的目录(默认:./input)
|
||||
|
||||
### 启动时自动扫描
|
||||
|
||||
当使用 `--auto-scan-at-startup` 参数启动任何服务器时,系统将自动:
|
||||
|
||||
1. 扫描输入目录中的新文件
|
||||
2. 为尚未在数据库中的新文档建立索引
|
||||
3. 使所有内容立即可用于 RAG 查询
|
||||
|
||||
> `--input-dir` 参数指定要扫描的输入目录。您可以从 webui 触发输入目录扫描。
|
||||
|
||||
### Gunicorn + Uvicorn 的多工作进程
|
||||
|
||||
LightRAG 服务器可以在 `Gunicorn + Uvicorn` 预加载模式下运行。Gunicorn 的多工作进程(多进程)功能可以防止文档索引任务阻塞 RAG 查询。使用 CPU 密集型文档提取工具(如 docling)在纯 Uvicorn 模式下可能会导致整个系统被阻塞。
|
||||
|
||||
虽然 LightRAG 服务器使用一个工作进程来处理文档索引流程,但通过 Uvicorn 的异步任务支持,可以并行处理多个文件。文档索引速度的瓶颈主要在于 LLM。如果您的 LLM 支持高并发,您可以通过增加 LLM 的并发级别来加速文档索引。以下是几个与并发处理相关的环境变量及其默认值:
|
||||
|
||||
```
|
||||
WORKERS=2 # 工作进程数,不大于 (2 x 核心数) + 1
|
||||
MAX_PARALLEL_INSERT=2 # 一批中并行处理的文件数
|
||||
MAX_ASYNC=4 # LLM 的最大并发请求数
|
||||
```
|
||||
|
||||
### 将 Lightrag 安装为 Linux 服务
|
||||
|
||||
从示例文件 `lightrag.sevice.example` 创建您的服务文件 `lightrag.sevice`。修改服务文件中的 WorkingDirectory 和 ExecStart:
|
||||
|
||||
```text
|
||||
Description=LightRAG Ollama Service
|
||||
WorkingDirectory=<lightrag 安装目录>
|
||||
ExecStart=<lightrag 安装目录>/lightrag/api/lightrag-api
|
||||
```
|
||||
|
||||
修改您的服务启动脚本:`lightrag-api`。根据需要更改 python 虚拟环境激活命令:
|
||||
|
||||
```shell
|
||||
#!/bin/bash
|
||||
|
||||
# 您的 python 虚拟环境激活命令
|
||||
source /home/netman/lightrag-xyj/venv/bin/activate
|
||||
# 启动 lightrag api 服务器
|
||||
lightrag-server
|
||||
```
|
||||
|
||||
安装 LightRAG 服务。如果您的系统是 Ubuntu,以下命令将生效:
|
||||
|
||||
```shell
|
||||
sudo cp lightrag.service /etc/systemd/system/
|
||||
sudo systemctl daemon-reload
|
||||
sudo systemctl start lightrag.service
|
||||
sudo systemctl status lightrag.service
|
||||
sudo systemctl enable lightrag.service
|
||||
```
|
||||
|
||||
## Ollama 模拟
|
||||
|
||||
我们为 LightRAG 提供了 Ollama 兼容接口,旨在将 LightRAG 模拟为 Ollama 聊天模型。这使得支持 Ollama 的 AI 聊天前端(如 Open WebUI)可以轻松访问 LightRAG。
|
||||
|
||||
### 将 Open WebUI 连接到 LightRAG
|
||||
|
||||
启动 lightrag-server 后,您可以在 Open WebUI 管理面板中添加 Ollama 类型的连接。然后,一个名为 lightrag:latest 的模型将出现在 Open WebUI 的模型管理界面中。用户随后可以通过聊天界面向 LightRAG 发送查询。对于这种用例,最好将 LightRAG 安装为服务。
|
||||
|
||||
Open WebUI 使用 LLM 来执行会话标题和会话关键词生成任务。因此,Ollama 聊天补全 API 会检测并将 OpenWebUI 会话相关请求直接转发给底层 LLM。Open WebUI 的截图:
|
||||
|
||||

|
||||
|
||||
### 在聊天中选择查询模式
|
||||
|
||||
查询字符串中的查询前缀可以决定使用哪种 LightRAG 查询模式来生成响应。支持的前缀包括:
|
||||
|
||||
```
|
||||
/local
|
||||
/global
|
||||
/hybrid
|
||||
/naive
|
||||
/mix
|
||||
/bypass
|
||||
```
|
||||
|
||||
例如,聊天消息 "/mix 唐僧有几个徒弟" 将触发 LightRAG 的混合模式查询。没有查询前缀的聊天消息默认会触发混合模式查询。
|
||||
|
||||
"/bypass" 不是 LightRAG 查询模式,它会告诉 API 服务器将查询连同聊天历史直接传递给底层 LLM。因此用户可以使用 LLM 基于聊天历史回答问题。如果您使用 Open WebUI 作为前端,您可以直接切换到普通 LLM 模型,而不是使用 /bypass 前缀。
|
||||
|
||||
## API 密钥和认证
|
||||
|
||||
默认情况下,LightRAG 服务器可以在没有任何认证的情况下访问。我们可以使用 API 密钥或账户凭证配置服务器以确保其安全。
|
||||
|
||||
* API 密钥
|
||||
|
||||
```
|
||||
LIGHTRAG_API_KEY=your-secure-api-key-here
|
||||
WHITELIST_PATHS=/health,/api/*
|
||||
```
|
||||
|
||||
> 健康检查和 Ollama 模拟端点默认不进行 API 密钥检查。
|
||||
|
||||
* 账户凭证(Web 界面需要登录后才能访问)
|
||||
|
||||
LightRAG API 服务器使用基于 HS256 算法的 JWT 认证。要启用安全访问控制,需要以下环境变量:
|
||||
|
||||
```bash
|
||||
# JWT 认证
|
||||
AUTH_USERNAME=admin # 登录名
|
||||
AUTH_PASSWORD=admin123 # 密码
|
||||
TOKEN_SECRET=your-key # JWT 密钥
|
||||
TOKEN_EXPIRE_HOURS=4 # 过期时间
|
||||
```
|
||||
|
||||
> 目前仅支持配置一个管理员账户和密码。尚未开发和实现完整的账户系统。
|
||||
|
||||
如果未配置账户凭证,Web 界面将以访客身份访问系统。因此,即使仅配置了 API 密钥,所有 API 仍然可以通过访客账户访问,这仍然不安全。因此,要保护 API,需要同时配置这两种认证方法。
|
||||
|
||||
## Azure OpenAI 后端配置
|
||||
|
||||
可以使用以下 Azure CLI 命令创建 Azure OpenAI API(您需要先从 [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli) 安装 Azure CLI):
|
||||
|
||||
```bash
|
||||
# 根据需要更改资源组名称、位置和 OpenAI 资源名称
|
||||
RESOURCE_GROUP_NAME=LightRAG
|
||||
LOCATION=swedencentral
|
||||
RESOURCE_NAME=LightRAG-OpenAI
|
||||
|
||||
az login
|
||||
az group create --name $RESOURCE_GROUP_NAME --location $LOCATION
|
||||
az cognitiveservices account create --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --kind OpenAI --sku S0 --location swedencentral
|
||||
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name gpt-4o --model-name gpt-4o --model-version "2024-08-06" --sku-capacity 100 --sku-name "Standard"
|
||||
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name text-embedding-3-large --model-name text-embedding-3-large --model-version "1" --sku-capacity 80 --sku-name "Standard"
|
||||
az cognitiveservices account show --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --query "properties.endpoint"
|
||||
az cognitiveservices account keys list --name $RESOURCE_NAME -g $RESOURCE_GROUP_NAME
|
||||
```
|
||||
|
||||
最后一个命令的输出将提供 OpenAI API 的端点和密钥。您可以使用这些值在 `.env` 文件中设置环境变量。
|
||||
|
||||
```
|
||||
# .env 中的 Azure OpenAI 配置
|
||||
LLM_BINDING=azure_openai
|
||||
LLM_BINDING_HOST=your-azure-endpoint
|
||||
LLM_MODEL=your-model-deployment-name
|
||||
LLM_BINDING_API_KEY=your-azure-api-key
|
||||
AZURE_OPENAI_API_VERSION=2024-08-01-preview # 可选,默认为最新版本
|
||||
EMBEDDING_BINDING=azure_openai # 如果使用 Azure OpenAI 进行嵌入
|
||||
EMBEDDING_MODEL=your-embedding-deployment-name
|
||||
```
|
||||
|
||||
## LightRAG 服务器详细配置
|
||||
|
||||
API 服务器可以通过三种方式配置(优先级从高到低):
|
||||
|
||||
* 命令行参数
|
||||
* 环境变量或 .env 文件
|
||||
* Config.ini(仅用于存储配置)
|
||||
|
||||
大多数配置都有默认设置,详细信息请查看示例文件:`.env.example`。数据存储配置也可以通过 config.ini 设置。为方便起见,提供了示例文件 `config.ini.example`。
|
||||
|
||||
### 支持的 LLM 和嵌入后端
|
||||
|
||||
LightRAG 支持绑定到各种 LLM/嵌入后端:
|
||||
|
||||
* ollama
|
||||
* lollms
|
||||
* openai 和 openai 兼容
|
||||
* azure_openai
|
||||
|
||||
使用环境变量 `LLM_BINDING` 或 CLI 参数 `--llm-binding` 选择 LLM 后端类型。使用环境变量 `EMBEDDING_BINDING` 或 CLI 参数 `--embedding-binding` 选择嵌入后端类型。
|
||||
|
||||
### 实体提取配置
|
||||
* ENABLE_LLM_CACHE_FOR_EXTRACT:为实体提取启用 LLM 缓存(默认:true)
|
||||
|
||||
在测试环境中将 `ENABLE_LLM_CACHE_FOR_EXTRACT` 设置为 true 以减少 LLM 调用成本是很常见的做法。
|
||||
|
||||
### 支持的存储类型
|
||||
|
||||
LightRAG 使用 4 种类型的存储用于不同目的:
|
||||
|
||||
* KV_STORAGE:llm 响应缓存、文本块、文档信息
|
||||
* VECTOR_STORAGE:实体向量、关系向量、块向量
|
||||
* GRAPH_STORAGE:实体关系图
|
||||
* DOC_STATUS_STORAGE:文档索引状态
|
||||
|
||||
每种存储类型都有几种实现:
|
||||
|
||||
* KV_STORAGE 支持的实现名称
|
||||
|
||||
```
|
||||
JsonKVStorage JsonFile(默认)
|
||||
MongoKVStorage MogonDB
|
||||
RedisKVStorage Redis
|
||||
TiDBKVStorage TiDB
|
||||
PGKVStorage Postgres
|
||||
OracleKVStorage Oracle
|
||||
```
|
||||
|
||||
* GRAPH_STORAGE 支持的实现名称
|
||||
|
||||
```
|
||||
NetworkXStorage NetworkX(默认)
|
||||
Neo4JStorage Neo4J
|
||||
MongoGraphStorage MongoDB
|
||||
TiDBGraphStorage TiDB
|
||||
AGEStorage AGE
|
||||
GremlinStorage Gremlin
|
||||
PGGraphStorage Postgres
|
||||
OracleGraphStorage Postgres
|
||||
```
|
||||
|
||||
* VECTOR_STORAGE 支持的实现名称
|
||||
|
||||
```
|
||||
NanoVectorDBStorage NanoVector(默认)
|
||||
MilvusVectorDBStorge Milvus
|
||||
ChromaVectorDBStorage Chroma
|
||||
TiDBVectorDBStorage TiDB
|
||||
PGVectorStorage Postgres
|
||||
FaissVectorDBStorage Faiss
|
||||
QdrantVectorDBStorage Qdrant
|
||||
OracleVectorDBStorage Oracle
|
||||
MongoVectorDBStorage MongoDB
|
||||
```
|
||||
|
||||
* DOC_STATUS_STORAGE 支持的实现名称
|
||||
|
||||
```
|
||||
JsonDocStatusStorage JsonFile(默认)
|
||||
PGDocStatusStorage Postgres
|
||||
MongoDocStatusStorage MongoDB
|
||||
```
|
||||
|
||||
### 如何选择存储实现
|
||||
|
||||
您可以通过环境变量选择存储实现。在首次启动 API 服务器之前,您可以将以下环境变量设置为特定的存储实现名称:
|
||||
|
||||
```
|
||||
LIGHTRAG_KV_STORAGE=PGKVStorage
|
||||
LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
|
||||
LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
|
||||
LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
|
||||
```
|
||||
|
||||
在向 LightRAG 添加文档后,您不能更改存储实现选择。目前尚不支持从一个存储实现迁移到另一个存储实现。更多信息请阅读示例 env 文件或 config.ini 文件。
|
||||
|
||||
### LightRag API 服务器命令行选项
|
||||
|
||||
| 参数 | 默认值 | 描述 |
|
||||
|-----------|---------|-------------|
|
||||
| --host | 0.0.0.0 | 服务器主机 |
|
||||
| --port | 9621 | 服务器端口 |
|
||||
| --working-dir | ./rag_storage | RAG 存储的工作目录 |
|
||||
| --input-dir | ./inputs | 包含输入文档的目录 |
|
||||
| --max-async | 4 | 最大异步操作数 |
|
||||
| --max-tokens | 32768 | 最大 token 大小 |
|
||||
| --timeout | 150 | 超时时间(秒)。None 表示无限超时(不推荐) |
|
||||
| --log-level | INFO | 日志级别(DEBUG、INFO、WARNING、ERROR、CRITICAL) |
|
||||
| --verbose | - | 详细调试输出(True、False) |
|
||||
| --key | None | 用于认证的 API 密钥。保护 lightrag 服务器免受未授权访问 |
|
||||
| --ssl | False | 启用 HTTPS |
|
||||
| --ssl-certfile | None | SSL 证书文件路径(如果启用 --ssl 则必需) |
|
||||
| --ssl-keyfile | None | SSL 私钥文件路径(如果启用 --ssl 则必需) |
|
||||
| --top-k | 50 | 要检索的 top-k 项目数;在"local"模式下对应实体,在"global"模式下对应关系。 |
|
||||
| --cosine-threshold | 0.4 | 节点和关系检索的余弦阈值,与 top-k 一起控制节点和关系的检索。 |
|
||||
| --llm-binding | ollama | LLM 绑定类型(lollms、ollama、openai、openai-ollama、azure_openai) |
|
||||
| --embedding-binding | ollama | 嵌入绑定类型(lollms、ollama、openai、azure_openai) |
|
||||
| auto-scan-at-startup | - | 扫描输入目录中的新文件并开始索引 |
|
||||
|
||||
### 使用示例
|
||||
|
||||
#### 使用 ollama 默认本地服务器作为 llm 和嵌入后端运行 Lightrag 服务器
|
||||
|
||||
Ollama 是 llm 和嵌入的默认后端,因此默认情况下您可以不带参数运行 lightrag-server,将使用默认值。确保已安装 ollama 并且正在运行,且默认模型已安装在 ollama 上。
|
||||
|
||||
```bash
|
||||
# 使用 ollama 运行 lightrag,llm 使用 mistral-nemo:latest,嵌入使用 bge-m3:latest
|
||||
lightrag-server
|
||||
|
||||
# 使用认证密钥
|
||||
lightrag-server --key my-key
|
||||
```
|
||||
|
||||
#### 使用 lollms 默认本地服务器作为 llm 和嵌入后端运行 Lightrag 服务器
|
||||
|
||||
```bash
|
||||
# 使用 lollms 运行 lightrag,llm 使用 mistral-nemo:latest,嵌入使用 bge-m3:latest
|
||||
# 在 .env 或 config.ini 中配置 LLM_BINDING=lollms 和 EMBEDDING_BINDING=lollms
|
||||
lightrag-server
|
||||
|
||||
# 使用认证密钥
|
||||
lightrag-server --key my-key
|
||||
```
|
||||
|
||||
#### 使用 openai 服务器作为 llm 和嵌入后端运行 Lightrag 服务器
|
||||
|
||||
```bash
|
||||
# 使用 openai 运行 lightrag,llm 使用 GPT-4o-mini,嵌入使用 text-embedding-3-small
|
||||
# 在 .env 或 config.ini 中配置:
|
||||
# LLM_BINDING=openai
|
||||
# LLM_MODEL=GPT-4o-mini
|
||||
# EMBEDDING_BINDING=openai
|
||||
# EMBEDDING_MODEL=text-embedding-3-small
|
||||
lightrag-server
|
||||
|
||||
# 使用认证密钥
|
||||
lightrag-server --key my-key
|
||||
```
|
||||
|
||||
#### 使用 azure openai 服务器作为 llm 和嵌入后端运行 Lightrag 服务器
|
||||
|
||||
```bash
|
||||
# 使用 azure_openai 运行 lightrag
|
||||
# 在 .env 或 config.ini 中配置:
|
||||
# LLM_BINDING=azure_openai
|
||||
# LLM_MODEL=your-model
|
||||
# EMBEDDING_BINDING=azure_openai
|
||||
# EMBEDDING_MODEL=your-embedding-model
|
||||
lightrag-server
|
||||
|
||||
# 使用认证密钥
|
||||
lightrag-server --key my-key
|
||||
```
|
||||
|
||||
**重要说明:**
|
||||
- 对于 LoLLMs:确保指定的模型已安装在您的 LoLLMs 实例中
|
||||
- 对于 Ollama:确保指定的模型已安装在您的 Ollama 实例中
|
||||
- 对于 OpenAI:确保您已设置 OPENAI_API_KEY 环境变量
|
||||
- 对于 Azure OpenAI:按照先决条件部分所述构建和配置您的服务器
|
||||
|
||||
要获取任何服务器的帮助,使用 --help 标志:
|
||||
```bash
|
||||
lightrag-server --help
|
||||
```
|
||||
|
||||
注意:如果您不需要 API 功能,可以使用以下命令安装不带 API 支持的基本包:
|
||||
```bash
|
||||
pip install lightrag-hku
|
||||
```
|
||||
|
||||
## API 端点
|
||||
|
||||
所有服务器(LoLLMs、Ollama、OpenAI 和 Azure OpenAI)都为 RAG 功能提供相同的 REST API 端点。当 API 服务器运行时,访问:
|
||||
|
||||
- Swagger UI:http://localhost:9621/docs
|
||||
- ReDoc:http://localhost:9621/redoc
|
||||
|
||||
您可以使用提供的 curl 命令或通过 Swagger UI 界面测试 API 端点。确保:
|
||||
|
||||
1. 启动适当的后端服务(LoLLMs、Ollama 或 OpenAI)
|
||||
2. 启动 RAG 服务器
|
||||
3. 使用文档管理端点上传一些文档
|
||||
4. 使用查询端点查询系统
|
||||
5. 如果在输入目录中放入新文件,触发文档扫描
|
||||
|
||||
### 查询端点
|
||||
|
||||
#### POST /query
|
||||
使用不同搜索模式查询 RAG 系统。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:9621/query" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"query": "您的问题", "mode": "hybrid", ""}'
|
||||
```
|
||||
|
||||
#### POST /query/stream
|
||||
从 RAG 系统流式获取响应。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:9621/query/stream" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"query": "您的问题", "mode": "hybrid"}'
|
||||
```
|
||||
|
||||
### 文档管理端点
|
||||
|
||||
#### POST /documents/text
|
||||
直接将文本插入 RAG 系统。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:9621/documents/text" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"text": "您的文本内容", "description": "可选描述"}'
|
||||
```
|
||||
|
||||
#### POST /documents/file
|
||||
向 RAG 系统上传单个文件。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:9621/documents/file" \
|
||||
-F "file=@/path/to/your/document.txt" \
|
||||
-F "description=可选描述"
|
||||
```
|
||||
|
||||
#### POST /documents/batch
|
||||
一次上传多个文件。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:9621/documents/batch" \
|
||||
-F "files=@/path/to/doc1.txt" \
|
||||
-F "files=@/path/to/doc2.txt"
|
||||
```
|
||||
|
||||
#### POST /documents/scan
|
||||
|
||||
触发输入目录中新文件的文档扫描。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:9621/documents/scan" --max-time 1800
|
||||
```
|
||||
|
||||
> 根据所有新文件的预计索引时间调整 max-time。
|
||||
|
||||
#### DELETE /documents
|
||||
|
||||
从 RAG 系统中清除所有文档。
|
||||
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:9621/documents"
|
||||
```
|
||||
|
||||
### Ollama 模拟端点
|
||||
|
||||
#### GET /api/version
|
||||
|
||||
获取 Ollama 版本信息。
|
||||
|
||||
```bash
|
||||
curl http://localhost:9621/api/version
|
||||
```
|
||||
|
||||
#### GET /api/tags
|
||||
|
||||
获取 Ollama 可用模型。
|
||||
|
||||
```bash
|
||||
curl http://localhost:9621/api/tags
|
||||
```
|
||||
|
||||
#### POST /api/chat
|
||||
|
||||
处理聊天补全请求。通过根据查询前缀选择查询模式将用户查询路由到 LightRAG。检测并将 OpenWebUI 会话相关请求(用于元数据生成任务)直接转发给底层 LLM。
|
||||
|
||||
```shell
|
||||
curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json" -d \
|
||||
'{"model":"lightrag:latest","messages":[{"role":"user","content":"猪八戒是谁"}],"stream":true}'
|
||||
```
|
||||
|
||||
> 有关 Ollama API 的更多信息,请访问:[Ollama API 文档](https://github.com/ollama/ollama/blob/main/docs/api.md)
|
||||
|
||||
#### POST /api/generate
|
||||
|
||||
处理生成补全请求。为了兼容性目的,该请求不由 LightRAG 处理,而是由底层 LLM 模型处理。
|
||||
|
||||
### 实用工具端点
|
||||
|
||||
#### GET /health
|
||||
检查服务器健康状况和配置。
|
||||
|
||||
```bash
|
||||
curl "http://localhost:9621/health"
|
||||
|
||||
```
|
@@ -153,10 +153,6 @@ sudo systemctl status lightrag.service
|
||||
sudo systemctl enable lightrag.service
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Ollama Emulation
|
||||
|
||||
We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends supporting Ollama, such as Open WebUI, to access LightRAG easily.
|
||||
@@ -196,8 +192,11 @@ By default, the LightRAG Server can be accessed without any authentication. We c
|
||||
|
||||
```
|
||||
LIGHTRAG_API_KEY=your-secure-api-key-here
|
||||
WHITELIST_PATHS=/health,/api/*
|
||||
```
|
||||
|
||||
> Health check and Ollama emuluation endpoins is exclude from API-KEY check by default.
|
||||
|
||||
* Account credentials (the web UI requires login before access)
|
||||
|
||||
LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
|
||||
@@ -317,7 +316,7 @@ OracleGraphStorage Postgres
|
||||
|
||||
```
|
||||
NanoVectorDBStorage NanoVector(default)
|
||||
MilvusVectorDBStorge Milvus
|
||||
MilvusVectorDBStorage Milvus
|
||||
ChromaVectorDBStorage Chroma
|
||||
TiDBVectorDBStorage TiDB
|
||||
PGVectorStorage Postgres
|
||||
|
@@ -1 +1 @@
|
||||
__api_version__ = "1.2.2"
|
||||
__api_version__ = "1.2.5"
|
||||
|
@@ -18,7 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.api.utils_api import (
|
||||
get_api_key_dependency,
|
||||
get_combined_auth_dependency,
|
||||
parse_args,
|
||||
get_default_host,
|
||||
display_splash_screen,
|
||||
@@ -41,7 +41,6 @@ from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_pipeline_status_lock,
|
||||
initialize_pipeline_status,
|
||||
get_all_update_flags_status,
|
||||
)
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from .auth import auth_handler
|
||||
@@ -136,19 +135,28 @@ def create_app(args):
|
||||
await rag.finalize_storages()
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(
|
||||
title="LightRAG API",
|
||||
description="API for querying text using LightRAG with separate storage and input directories"
|
||||
app_kwargs = {
|
||||
"title": "LightRAG Server API",
|
||||
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
|
||||
+ "(With authentication)"
|
||||
if api_key
|
||||
else "",
|
||||
version=__api_version__,
|
||||
openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL
|
||||
docs_url="/docs", # Explicitly set docs URL
|
||||
redoc_url="/redoc", # Explicitly set redoc URL
|
||||
openapi_tags=[{"name": "api"}],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
"version": __api_version__,
|
||||
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
|
||||
"docs_url": "/docs", # Explicitly set docs URL
|
||||
"redoc_url": "/redoc", # Explicitly set redoc URL
|
||||
"openapi_tags": [{"name": "api"}],
|
||||
"lifespan": lifespan,
|
||||
}
|
||||
|
||||
# Configure Swagger UI parameters
|
||||
# Enable persistAuthorization and tryItOutEnabled for better user experience
|
||||
app_kwargs["swagger_ui_parameters"] = {
|
||||
"persistAuthorization": True,
|
||||
"tryItOutEnabled": True,
|
||||
}
|
||||
|
||||
app = FastAPI(**app_kwargs)
|
||||
|
||||
def get_cors_origins():
|
||||
"""Get allowed origins from environment variable
|
||||
@@ -168,8 +176,8 @@ def create_app(args):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Create the optional API key dependency
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
# Create combined auth dependency for all endpoints
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||
@@ -200,6 +208,7 @@ def create_app(args):
|
||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
kwargs["temperature"] = args.temperature
|
||||
return await openai_complete_if_cache(
|
||||
args.llm_model,
|
||||
prompt,
|
||||
@@ -222,6 +231,7 @@ def create_app(args):
|
||||
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
kwargs["temperature"] = args.temperature
|
||||
return await azure_openai_complete_if_cache(
|
||||
args.llm_model,
|
||||
prompt,
|
||||
@@ -302,6 +312,7 @@ def create_app(args):
|
||||
},
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
)
|
||||
else: # azure_openai
|
||||
rag = LightRAG(
|
||||
@@ -331,6 +342,7 @@ def create_app(args):
|
||||
},
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
)
|
||||
|
||||
# Add routes
|
||||
@@ -339,7 +351,7 @@ def create_app(args):
|
||||
app.include_router(create_graph_routes(rag, api_key))
|
||||
|
||||
# Add Ollama API routes
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
||||
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
|
||||
app.include_router(ollama_api.router, prefix="/api")
|
||||
|
||||
@app.get("/")
|
||||
@@ -347,7 +359,7 @@ def create_app(args):
|
||||
"""Redirect root path to /webui"""
|
||||
return RedirectResponse(url="/webui")
|
||||
|
||||
@app.get("/auth-status", dependencies=[Depends(optional_api_key)])
|
||||
@app.get("/auth-status")
|
||||
async def get_auth_status():
|
||||
"""Get authentication status and guest token if auth is not configured"""
|
||||
|
||||
@@ -373,7 +385,7 @@ def create_app(args):
|
||||
"api_version": __api_version__,
|
||||
}
|
||||
|
||||
@app.post("/login", dependencies=[Depends(optional_api_key)])
|
||||
@app.post("/login")
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
if not auth_handler.accounts:
|
||||
# Authentication not configured, return guest token
|
||||
@@ -406,12 +418,9 @@ def create_app(args):
|
||||
"api_version": __api_version__,
|
||||
}
|
||||
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
@app.get("/health", dependencies=[Depends(combined_auth)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
# Get update flags status for all namespaces
|
||||
update_status = await get_all_update_flags_status()
|
||||
|
||||
username = os.getenv("AUTH_USERNAME")
|
||||
password = os.getenv("AUTH_PASSWORD")
|
||||
if not (username and password):
|
||||
@@ -439,7 +448,6 @@ def create_app(args):
|
||||
"vector_storage": args.vector_storage,
|
||||
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
||||
},
|
||||
"update_status": update_status,
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"auth_mode": auth_mode,
|
||||
|
@@ -17,15 +17,13 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from lightrag import LightRAG
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from lightrag.api.utils_api import (
|
||||
get_api_key_dependency,
|
||||
get_combined_auth_dependency,
|
||||
global_args,
|
||||
get_auth_dependency,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/documents",
|
||||
tags=["documents"],
|
||||
dependencies=[Depends(get_auth_dependency())],
|
||||
)
|
||||
|
||||
# Temporary file prefix
|
||||
@@ -113,6 +111,7 @@ class PipelineStatusResponse(BaseModel):
|
||||
request_pending: Flag for pending request for processing
|
||||
latest_message: Latest message from pipeline processing
|
||||
history_messages: List of history messages
|
||||
update_status: Status of update flags for all namespaces
|
||||
"""
|
||||
|
||||
autoscanned: bool = False
|
||||
@@ -125,6 +124,7 @@ class PipelineStatusResponse(BaseModel):
|
||||
request_pending: bool = False
|
||||
latest_message: str = ""
|
||||
history_messages: Optional[List[str]] = None
|
||||
update_status: Optional[dict] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow" # Allow additional fields from the pipeline status
|
||||
@@ -475,8 +475,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
||||
if not new_files:
|
||||
return
|
||||
|
||||
# Get MAX_PARALLEL_INSERT from global_args
|
||||
max_parallel = global_args["max_parallel_insert"]
|
||||
# Get MAX_PARALLEL_INSERT from global_args["main_args"]
|
||||
max_parallel = global_args["main_args"].max_parallel_insert
|
||||
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
|
||||
batch_size = 2 * max_parallel
|
||||
|
||||
@@ -505,9 +505,10 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
||||
def create_document_routes(
|
||||
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
||||
):
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
# Create combined auth dependency for document routes
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
@router.post("/scan", dependencies=[Depends(optional_api_key)])
|
||||
@router.post("/scan", dependencies=[Depends(combined_auth)])
|
||||
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Trigger the scanning process for new documents.
|
||||
@@ -523,7 +524,7 @@ def create_document_routes(
|
||||
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
||||
return {"status": "scanning_started"}
|
||||
|
||||
@router.post("/upload", dependencies=[Depends(optional_api_key)])
|
||||
@router.post("/upload", dependencies=[Depends(combined_auth)])
|
||||
async def upload_to_input_dir(
|
||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||
):
|
||||
@@ -568,7 +569,7 @@ def create_document_routes(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post(
|
||||
"/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
||||
"/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def insert_text(
|
||||
request: InsertTextRequest, background_tasks: BackgroundTasks
|
||||
@@ -603,7 +604,7 @@ def create_document_routes(
|
||||
@router.post(
|
||||
"/texts",
|
||||
response_model=InsertResponse,
|
||||
dependencies=[Depends(optional_api_key)],
|
||||
dependencies=[Depends(combined_auth)],
|
||||
)
|
||||
async def insert_texts(
|
||||
request: InsertTextsRequest, background_tasks: BackgroundTasks
|
||||
@@ -636,7 +637,7 @@ def create_document_routes(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post(
|
||||
"/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
||||
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def insert_file(
|
||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||
@@ -681,7 +682,7 @@ def create_document_routes(
|
||||
@router.post(
|
||||
"/file_batch",
|
||||
response_model=InsertResponse,
|
||||
dependencies=[Depends(optional_api_key)],
|
||||
dependencies=[Depends(combined_auth)],
|
||||
)
|
||||
async def insert_batch(
|
||||
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
|
||||
@@ -742,7 +743,7 @@ def create_document_routes(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete(
|
||||
"", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
||||
"", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def clear_documents():
|
||||
"""
|
||||
@@ -771,7 +772,7 @@ def create_document_routes(
|
||||
|
||||
@router.get(
|
||||
"/pipeline_status",
|
||||
dependencies=[Depends(optional_api_key)],
|
||||
dependencies=[Depends(combined_auth)],
|
||||
response_model=PipelineStatusResponse,
|
||||
)
|
||||
async def get_pipeline_status() -> PipelineStatusResponse:
|
||||
@@ -798,13 +799,34 @@ def create_document_routes(
|
||||
HTTPException: If an error occurs while retrieving pipeline status (500)
|
||||
"""
|
||||
try:
|
||||
from lightrag.kg.shared_storage import get_namespace_data
|
||||
from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_all_update_flags_status,
|
||||
)
|
||||
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
|
||||
# Get update flags status for all namespaces
|
||||
update_status = await get_all_update_flags_status()
|
||||
|
||||
# Convert MutableBoolean objects to regular boolean values
|
||||
processed_update_status = {}
|
||||
for namespace, flags in update_status.items():
|
||||
processed_flags = []
|
||||
for flag in flags:
|
||||
# Handle both multiprocess and single process cases
|
||||
if hasattr(flag, "value"):
|
||||
processed_flags.append(bool(flag.value))
|
||||
else:
|
||||
processed_flags.append(bool(flag))
|
||||
processed_update_status[namespace] = processed_flags
|
||||
|
||||
# Convert to regular dict if it's a Manager.dict
|
||||
status_dict = dict(pipeline_status)
|
||||
|
||||
# Add processed update_status to the status dictionary
|
||||
status_dict["update_status"] = processed_update_status
|
||||
|
||||
# Convert history_messages to a regular list if it's a Manager.list
|
||||
if "history_messages" in status_dict:
|
||||
status_dict["history_messages"] = list(status_dict["history_messages"])
|
||||
@@ -819,7 +841,7 @@ def create_document_routes(
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("", dependencies=[Depends(optional_api_key)])
|
||||
@router.get("", dependencies=[Depends(combined_auth)])
|
||||
async def documents() -> DocsStatusesResponse:
|
||||
"""
|
||||
Get the status of all documents in the system.
|
||||
|
@@ -5,15 +5,15 @@ This module contains all graph-related routes for the LightRAG API.
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||
from ..utils_api import get_combined_auth_dependency
|
||||
|
||||
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
|
||||
router = APIRouter(tags=["graph"])
|
||||
|
||||
|
||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
|
||||
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
|
||||
async def get_graph_labels():
|
||||
"""
|
||||
Get all graph labels
|
||||
@@ -23,7 +23,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
"""
|
||||
return await rag.get_graph_labels()
|
||||
|
||||
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
||||
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
||||
async def get_knowledge_graph(
|
||||
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
||||
):
|
||||
|
@@ -11,7 +11,8 @@ import asyncio
|
||||
from ascii_colors import trace_exception
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.utils import encode_string_by_tiktoken
|
||||
from lightrag.api.utils_api import ollama_server_infos
|
||||
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
||||
from fastapi import Depends
|
||||
|
||||
|
||||
# query mode according to query prefix (bypass is not LightRAG quer mode)
|
||||
@@ -122,20 +123,24 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
|
||||
|
||||
|
||||
class OllamaAPI:
|
||||
def __init__(self, rag: LightRAG, top_k: int = 60):
|
||||
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
|
||||
self.rag = rag
|
||||
self.ollama_server_infos = ollama_server_infos
|
||||
self.top_k = top_k
|
||||
self.api_key = api_key
|
||||
self.router = APIRouter(tags=["ollama"])
|
||||
self.setup_routes()
|
||||
|
||||
def setup_routes(self):
|
||||
@self.router.get("/version")
|
||||
# Create combined auth dependency for Ollama API routes
|
||||
combined_auth = get_combined_auth_dependency(self.api_key)
|
||||
|
||||
@self.router.get("/version", dependencies=[Depends(combined_auth)])
|
||||
async def get_version():
|
||||
"""Get Ollama version information"""
|
||||
return OllamaVersionResponse(version="0.5.4")
|
||||
|
||||
@self.router.get("/tags")
|
||||
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
|
||||
async def get_tags():
|
||||
"""Return available models acting as an Ollama server"""
|
||||
return OllamaTagResponse(
|
||||
@@ -158,7 +163,7 @@ class OllamaAPI:
|
||||
]
|
||||
)
|
||||
|
||||
@self.router.post("/generate")
|
||||
@self.router.post("/generate", dependencies=[Depends(combined_auth)])
|
||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
||||
"""Handle generate completion requests acting as an Ollama model
|
||||
For compatibility purpose, the request is not processed by LightRAG,
|
||||
@@ -324,7 +329,7 @@ class OllamaAPI:
|
||||
trace_exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@self.router.post("/chat")
|
||||
@self.router.post("/chat", dependencies=[Depends(combined_auth)])
|
||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
||||
"""Process chat completion requests acting as an Ollama model
|
||||
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
||||
|
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from lightrag.base import QueryParam
|
||||
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
||||
from ..utils_api import get_combined_auth_dependency
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from ascii_colors import trace_exception
|
||||
|
||||
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
|
||||
router = APIRouter(tags=["query"])
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
@@ -139,10 +139,10 @@ class QueryResponse(BaseModel):
|
||||
|
||||
|
||||
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
@router.post(
|
||||
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
|
||||
"/query", response_model=QueryResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def query_text(request: QueryRequest):
|
||||
"""
|
||||
@@ -176,7 +176,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||
trace_exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
||||
@router.post("/query/stream", dependencies=[Depends(combined_auth)])
|
||||
async def query_text_stream(request: QueryRequest):
|
||||
"""
|
||||
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
|
||||
|
@@ -4,22 +4,42 @@ Utility functions for the LightRAG API.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Tuple
|
||||
import sys
|
||||
import logging
|
||||
from ascii_colors import ASCIIColors
|
||||
from lightrag.api import __api_version__
|
||||
from fastapi import HTTPException, Security, Depends, Request, status
|
||||
from fastapi import HTTPException, Security, Request, status
|
||||
from dotenv import load_dotenv
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from .auth import auth_handler
|
||||
from ..prompt import PROMPTS
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
global_args = {"main_args": None}
|
||||
|
||||
# Get whitelist paths from environment variable, only once during initialization
|
||||
default_whitelist = "/health,/api/*"
|
||||
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
|
||||
|
||||
# Pre-compile path matching patterns
|
||||
whitelist_patterns: List[Tuple[str, bool]] = []
|
||||
for path in whitelist_paths:
|
||||
path = path.strip()
|
||||
if path:
|
||||
# If path ends with /*, match all paths with that prefix
|
||||
if path.endswith("/*"):
|
||||
prefix = path[:-2]
|
||||
whitelist_patterns.append((prefix, True)) # (prefix, is_prefix_match)
|
||||
else:
|
||||
whitelist_patterns.append((path, False)) # (exact_path, is_prefix_match)
|
||||
|
||||
# Global authentication configuration
|
||||
auth_configured = bool(auth_handler.accounts)
|
||||
|
||||
|
||||
class OllamaServerInfos:
|
||||
# Constants for emulated Ollama model information
|
||||
@@ -34,47 +54,114 @@ class OllamaServerInfos:
|
||||
ollama_server_infos = OllamaServerInfos()
|
||||
|
||||
|
||||
def get_auth_dependency():
|
||||
# Set default whitelist paths
|
||||
whitelist = os.getenv("WHITELIST_PATHS", "/login,/health").split(",")
|
||||
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
"""
|
||||
Create a combined authentication dependency that implements authentication logic
|
||||
based on API key, OAuth2 token, and whitelist paths.
|
||||
|
||||
async def dependency(
|
||||
Args:
|
||||
api_key (Optional[str]): API key for validation
|
||||
|
||||
Returns:
|
||||
Callable: A dependency function that implements the authentication logic
|
||||
"""
|
||||
# Use global whitelist_patterns and auth_configured variables
|
||||
# whitelist_patterns and auth_configured are already initialized at module level
|
||||
|
||||
# Only calculate api_key_configured as it depends on the function parameter
|
||||
api_key_configured = bool(api_key)
|
||||
|
||||
# Create security dependencies with proper descriptions for Swagger UI
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="login", auto_error=False, description="OAuth2 Password Authentication"
|
||||
)
|
||||
|
||||
# If API key is configured, create an API key header security
|
||||
api_key_header = None
|
||||
if api_key_configured:
|
||||
api_key_header = APIKeyHeader(
|
||||
name="X-API-Key", auto_error=False, description="API Key Authentication"
|
||||
)
|
||||
|
||||
async def combined_dependency(
|
||||
request: Request,
|
||||
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
|
||||
token: str = Security(oauth2_scheme),
|
||||
api_key_header_value: Optional[str] = None
|
||||
if api_key_header is None
|
||||
else Security(api_key_header),
|
||||
):
|
||||
# Check if authentication is configured
|
||||
auth_configured = bool(auth_handler.accounts)
|
||||
# 1. Check if path is in whitelist
|
||||
path = request.url.path
|
||||
for pattern, is_prefix in whitelist_patterns:
|
||||
if (is_prefix and path.startswith(pattern)) or (
|
||||
not is_prefix and path == pattern
|
||||
):
|
||||
return # Whitelist path, allow access
|
||||
|
||||
# If authentication is not configured, skip all validation
|
||||
if not auth_configured:
|
||||
return
|
||||
# 2. Validate token first if provided in the request (Ensure 401 error if token is invalid)
|
||||
if token:
|
||||
try:
|
||||
token_info = auth_handler.validate_token(token)
|
||||
# Accept guest token if no auth is configured
|
||||
if not auth_configured and token_info.get("role") == "guest":
|
||||
return
|
||||
# Accept non-guest token if auth is configured
|
||||
if auth_configured and token_info.get("role") != "guest":
|
||||
return
|
||||
|
||||
# For configured auth, allow whitelist paths without token
|
||||
if request.url.path in whitelist:
|
||||
return
|
||||
|
||||
# Require token for all other paths when auth is configured
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required"
|
||||
)
|
||||
|
||||
try:
|
||||
token_info = auth_handler.validate_token(token)
|
||||
# Reject guest tokens when authentication is configured
|
||||
if token_info.get("role") == "guest":
|
||||
# Token validation failed, immediately return 401 error
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required. Guest access not allowed when authentication is configured.",
|
||||
detail="Invalid token. Please login again.",
|
||||
)
|
||||
except Exception:
|
||||
except HTTPException as e:
|
||||
# If already a 401 error, re-raise it
|
||||
if e.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||
raise
|
||||
# For other exceptions, continue processing
|
||||
|
||||
# 3. Acept all request if no API protection needed
|
||||
if not auth_configured and not api_key_configured:
|
||||
return
|
||||
|
||||
# 4. Validate API key if provided and API-Key authentication is configured
|
||||
if (
|
||||
api_key_configured
|
||||
and api_key_header_value
|
||||
and api_key_header_value == api_key
|
||||
):
|
||||
return # API key validation successful
|
||||
|
||||
### Authentication failed ####
|
||||
|
||||
# if password authentication is configured but not provided, ensure 401 error if auth_configured
|
||||
if auth_configured and not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No credentials provided. Please login.",
|
||||
)
|
||||
|
||||
return
|
||||
# if api key is provided but validation failed
|
||||
if api_key_header_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API Key",
|
||||
)
|
||||
|
||||
return dependency
|
||||
# if api_key_configured but not provided
|
||||
if api_key_configured and not api_key_header_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="API Key required",
|
||||
)
|
||||
|
||||
# Otherwise: refuse access and return 403 error
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="API Key required or login authentication required.",
|
||||
)
|
||||
|
||||
return combined_dependency
|
||||
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
@@ -88,19 +175,37 @@ def get_api_key_dependency(api_key: Optional[str]):
|
||||
Returns:
|
||||
Callable: A dependency function that validates the API key.
|
||||
"""
|
||||
if not api_key:
|
||||
# Use global whitelist_patterns and auth_configured variables
|
||||
# whitelist_patterns and auth_configured are already initialized at module level
|
||||
|
||||
# Only calculate api_key_configured as it depends on the function parameter
|
||||
api_key_configured = bool(api_key)
|
||||
|
||||
if not api_key_configured:
|
||||
# If no API key is configured, return a dummy dependency that always succeeds
|
||||
async def no_auth():
|
||||
async def no_auth(request: Request = None, **kwargs):
|
||||
return None
|
||||
|
||||
return no_auth
|
||||
|
||||
# If API key is configured, use proper authentication
|
||||
# If API key is configured, use proper authentication with Security for Swagger UI
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def api_key_auth(
|
||||
api_key_header_value: Optional[str] = Security(api_key_header),
|
||||
request: Request,
|
||||
api_key_header_value: Optional[str] = Security(
|
||||
api_key_header, description="API Key for authentication"
|
||||
),
|
||||
):
|
||||
# Check if request path is in whitelist
|
||||
path = request.url.path
|
||||
for pattern, is_prefix in whitelist_patterns:
|
||||
if (is_prefix and path.startswith(pattern)) or (
|
||||
not is_prefix and path == pattern
|
||||
):
|
||||
return # Whitelist path, allow access
|
||||
|
||||
# Non-whitelist path, validate API key
|
||||
if not api_key_header_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
|
||||
@@ -364,7 +469,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
||||
)
|
||||
|
||||
# Get MAX_PARALLEL_INSERT from environment
|
||||
global_args["max_parallel_insert"] = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
||||
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
||||
|
||||
# Handle openai-ollama special case
|
||||
if args.llm_binding == "openai-ollama":
|
||||
@@ -395,6 +500,9 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
||||
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
|
||||
)
|
||||
|
||||
# Inject LLM temperature configuration
|
||||
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
|
||||
|
||||
# Select Document loading tool (DOCLING, DEFAULT)
|
||||
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
||||
|
||||
@@ -462,6 +570,12 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.llm_binding_host}")
|
||||
ASCIIColors.white(" ├─ Model: ", end="")
|
||||
ASCIIColors.yellow(f"{args.llm_model}")
|
||||
ASCIIColors.white(" ├─ Temperature: ", end="")
|
||||
ASCIIColors.yellow(f"{args.temperature}")
|
||||
ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_async}")
|
||||
ASCIIColors.white(" ├─ Max Tokens: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_tokens}")
|
||||
ASCIIColors.white(" └─ Timeout: ", end="")
|
||||
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
|
||||
|
||||
@@ -477,13 +591,12 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.embedding_dim}")
|
||||
|
||||
# RAG Configuration
|
||||
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
|
||||
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
||||
ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_async}")
|
||||
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
||||
ASCIIColors.yellow(f"{summary_language}")
|
||||
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
||||
ASCIIColors.yellow(f"{global_args['max_parallel_insert']}")
|
||||
ASCIIColors.white(" ├─ Max Tokens: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_tokens}")
|
||||
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
||||
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_embed_tokens}")
|
||||
ASCIIColors.white(" ├─ Chunk Size: ", end="")
|
||||
|
1
lightrag/api/webui/assets/index-BcBS1RaQ.css
generated
1
lightrag/api/webui/assets/index-BcBS1RaQ.css
generated
File diff suppressed because one or more lines are too long
1
lightrag/api/webui/assets/index-BwFyYQzx.css
generated
Normal file
1
lightrag/api/webui/assets/index-BwFyYQzx.css
generated
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
4
lightrag/api/webui/index.html
generated
4
lightrag/api/webui/index.html
generated
@@ -8,8 +8,8 @@
|
||||
<link rel="icon" type="image/svg+xml" href="logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Lightrag</title>
|
||||
<script type="module" crossorigin src="/webui/assets/index-qXLILB5u.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/webui/assets/index-BcBS1RaQ.css">
|
||||
<script type="module" crossorigin src="/webui/assets/index-DJ53id6i.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/webui/assets/index-BwFyYQzx.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
@@ -19,7 +19,6 @@ from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
is_multiprocess,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,9 +72,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if (is_multiprocess and self.storage_updated.value) or (
|
||||
not is_multiprocess and self.storage_updated
|
||||
):
|
||||
if self.storage_updated.value:
|
||||
logger.info(
|
||||
f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
|
||||
)
|
||||
@@ -83,10 +80,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
self.storage_updated.value = False
|
||||
return self._index
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
@@ -343,18 +337,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
self._id_to_meta = {}
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Check if storage was updated by another process
|
||||
if is_multiprocess and self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
async with self._storage_lock:
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
async with self._storage_lock:
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
@@ -364,10 +359,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
self.storage_updated.value = False
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
|
||||
return False # Return error
|
||||
|
@@ -221,6 +221,7 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||
created_at=doc.get("created_at"),
|
||||
updated_at=doc.get("updated_at"),
|
||||
chunks_count=doc.get("chunks_count", -1),
|
||||
file_path=doc.get("file_path", doc["_id"]),
|
||||
)
|
||||
for doc in result
|
||||
}
|
||||
|
@@ -20,7 +20,6 @@ from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
is_multiprocess,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,16 +56,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(self.namespace)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_storage_lock()
|
||||
self._storage_lock = get_storage_lock(enable_logging=False)
|
||||
|
||||
async def _get_client(self):
|
||||
"""Check if the storage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if data needs to be reloaded
|
||||
if (is_multiprocess and self.storage_updated.value) or (
|
||||
not is_multiprocess and self.storage_updated
|
||||
):
|
||||
if self.storage_updated.value:
|
||||
logger.info(
|
||||
f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
|
||||
)
|
||||
@@ -76,10 +73,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
# Reset update flag
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
self.storage_updated.value = False
|
||||
|
||||
return self._client
|
||||
|
||||
@@ -206,19 +200,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
async def index_done_callback(self) -> bool:
|
||||
"""Save data to disk"""
|
||||
# Check if storage was updated by another process
|
||||
if is_multiprocess and self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Storage for {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Storage for {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
@@ -228,10 +223,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
self.storage_updated.value = False
|
||||
return True # Return success
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving data for {self.namespace}: {e}")
|
||||
|
@@ -21,7 +21,6 @@ from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
is_multiprocess,
|
||||
)
|
||||
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
@@ -110,9 +109,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if data needs to be reloaded
|
||||
if (is_multiprocess and self.storage_updated.value) or (
|
||||
not is_multiprocess and self.storage_updated
|
||||
):
|
||||
if self.storage_updated.value:
|
||||
logger.info(
|
||||
f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
|
||||
)
|
||||
@@ -121,10 +118,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||
)
|
||||
# Reset update flag
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
self.storage_updated.value = False
|
||||
|
||||
return self._graph
|
||||
|
||||
@@ -401,18 +395,19 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
async def index_done_callback(self) -> bool:
|
||||
"""Save data to disk"""
|
||||
# Check if storage was updated by another process
|
||||
if is_multiprocess and self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Graph for {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._graph = (
|
||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Graph for {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._graph = (
|
||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
@@ -422,10 +417,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
self.storage_updated.value = False
|
||||
return True # Return success
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving graph for {self.namespace}: {e}")
|
||||
|
@@ -24,7 +24,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
|
||||
T = TypeVar("T")
|
||||
LockType = Union[ProcessLock, asyncio.Lock]
|
||||
|
||||
is_multiprocess = None
|
||||
_is_multiprocess = None
|
||||
_workers = None
|
||||
_manager = None
|
||||
_initialized = None
|
||||
@@ -218,10 +218,10 @@ class UnifiedLock(Generic[T]):
|
||||
|
||||
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
"""return unified storage lock for data consistency"""
|
||||
async_lock = _async_locks.get("internal_lock") if is_multiprocess else None
|
||||
async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
|
||||
return UnifiedLock(
|
||||
lock=_internal_lock,
|
||||
is_async=not is_multiprocess,
|
||||
is_async=not _is_multiprocess,
|
||||
name="internal_lock",
|
||||
enable_logging=enable_logging,
|
||||
async_lock=async_lock,
|
||||
@@ -230,10 +230,10 @@ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
|
||||
def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
"""return unified storage lock for data consistency"""
|
||||
async_lock = _async_locks.get("storage_lock") if is_multiprocess else None
|
||||
async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None
|
||||
return UnifiedLock(
|
||||
lock=_storage_lock,
|
||||
is_async=not is_multiprocess,
|
||||
is_async=not _is_multiprocess,
|
||||
name="storage_lock",
|
||||
enable_logging=enable_logging,
|
||||
async_lock=async_lock,
|
||||
@@ -242,10 +242,10 @@ def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
|
||||
def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
"""return unified storage lock for data consistency"""
|
||||
async_lock = _async_locks.get("pipeline_status_lock") if is_multiprocess else None
|
||||
async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None
|
||||
return UnifiedLock(
|
||||
lock=_pipeline_status_lock,
|
||||
is_async=not is_multiprocess,
|
||||
is_async=not _is_multiprocess,
|
||||
name="pipeline_status_lock",
|
||||
enable_logging=enable_logging,
|
||||
async_lock=async_lock,
|
||||
@@ -254,10 +254,10 @@ def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
|
||||
def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
"""return unified graph database lock for ensuring atomic operations"""
|
||||
async_lock = _async_locks.get("graph_db_lock") if is_multiprocess else None
|
||||
async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None
|
||||
return UnifiedLock(
|
||||
lock=_graph_db_lock,
|
||||
is_async=not is_multiprocess,
|
||||
is_async=not _is_multiprocess,
|
||||
name="graph_db_lock",
|
||||
enable_logging=enable_logging,
|
||||
async_lock=async_lock,
|
||||
@@ -266,10 +266,10 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
|
||||
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
|
||||
"""return unified data initialization lock for ensuring atomic data initialization"""
|
||||
async_lock = _async_locks.get("data_init_lock") if is_multiprocess else None
|
||||
async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
|
||||
return UnifiedLock(
|
||||
lock=_data_init_lock,
|
||||
is_async=not is_multiprocess,
|
||||
is_async=not _is_multiprocess,
|
||||
name="data_init_lock",
|
||||
enable_logging=enable_logging,
|
||||
async_lock=async_lock,
|
||||
@@ -297,7 +297,7 @@ def initialize_share_data(workers: int = 1):
|
||||
global \
|
||||
_manager, \
|
||||
_workers, \
|
||||
is_multiprocess, \
|
||||
_is_multiprocess, \
|
||||
_storage_lock, \
|
||||
_internal_lock, \
|
||||
_pipeline_status_lock, \
|
||||
@@ -312,14 +312,14 @@ def initialize_share_data(workers: int = 1):
|
||||
# Check if already initialized
|
||||
if _initialized:
|
||||
direct_log(
|
||||
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
|
||||
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={_is_multiprocess})"
|
||||
)
|
||||
return
|
||||
|
||||
_workers = workers
|
||||
|
||||
if workers > 1:
|
||||
is_multiprocess = True
|
||||
_is_multiprocess = True
|
||||
_manager = Manager()
|
||||
_internal_lock = _manager.Lock()
|
||||
_storage_lock = _manager.Lock()
|
||||
@@ -343,7 +343,7 @@ def initialize_share_data(workers: int = 1):
|
||||
f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
|
||||
)
|
||||
else:
|
||||
is_multiprocess = False
|
||||
_is_multiprocess = False
|
||||
_internal_lock = asyncio.Lock()
|
||||
_storage_lock = asyncio.Lock()
|
||||
_pipeline_status_lock = asyncio.Lock()
|
||||
@@ -372,7 +372,7 @@ async def initialize_pipeline_status():
|
||||
return
|
||||
|
||||
# Create a shared list object for history_messages
|
||||
history_messages = _manager.list() if is_multiprocess else []
|
||||
history_messages = _manager.list() if _is_multiprocess else []
|
||||
pipeline_namespace.update(
|
||||
{
|
||||
"autoscanned": False, # Auto-scan started
|
||||
@@ -401,7 +401,7 @@ async def get_update_flag(namespace: str):
|
||||
|
||||
async with get_internal_lock():
|
||||
if namespace not in _update_flags:
|
||||
if is_multiprocess and _manager is not None:
|
||||
if _is_multiprocess and _manager is not None:
|
||||
_update_flags[namespace] = _manager.list()
|
||||
else:
|
||||
_update_flags[namespace] = []
|
||||
@@ -409,7 +409,7 @@ async def get_update_flag(namespace: str):
|
||||
f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
|
||||
)
|
||||
|
||||
if is_multiprocess and _manager is not None:
|
||||
if _is_multiprocess and _manager is not None:
|
||||
new_update_flag = _manager.Value("b", False)
|
||||
else:
|
||||
# Create a simple mutable object to store boolean value for compatibility with mutiprocess
|
||||
@@ -434,11 +434,7 @@ async def set_all_update_flags(namespace: str):
|
||||
raise ValueError(f"Namespace {namespace} not found in update flags")
|
||||
# Update flags for both modes
|
||||
for i in range(len(_update_flags[namespace])):
|
||||
if is_multiprocess:
|
||||
_update_flags[namespace][i].value = True
|
||||
else:
|
||||
# Use .value attribute instead of direct assignment
|
||||
_update_flags[namespace][i].value = True
|
||||
_update_flags[namespace][i].value = True
|
||||
|
||||
|
||||
async def clear_all_update_flags(namespace: str):
|
||||
@@ -452,11 +448,7 @@ async def clear_all_update_flags(namespace: str):
|
||||
raise ValueError(f"Namespace {namespace} not found in update flags")
|
||||
# Update flags for both modes
|
||||
for i in range(len(_update_flags[namespace])):
|
||||
if is_multiprocess:
|
||||
_update_flags[namespace][i].value = False
|
||||
else:
|
||||
# Use .value attribute instead of direct assignment
|
||||
_update_flags[namespace][i].value = False
|
||||
_update_flags[namespace][i].value = False
|
||||
|
||||
|
||||
async def get_all_update_flags_status() -> Dict[str, list]:
|
||||
@@ -474,7 +466,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
|
||||
for namespace, flags in _update_flags.items():
|
||||
worker_statuses = []
|
||||
for flag in flags:
|
||||
if is_multiprocess:
|
||||
if _is_multiprocess:
|
||||
worker_statuses.append(flag.value)
|
||||
else:
|
||||
worker_statuses.append(flag)
|
||||
@@ -518,7 +510,7 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
||||
|
||||
async with get_internal_lock():
|
||||
if namespace not in _shared_dicts:
|
||||
if is_multiprocess and _manager is not None:
|
||||
if _is_multiprocess and _manager is not None:
|
||||
_shared_dicts[namespace] = _manager.dict()
|
||||
else:
|
||||
_shared_dicts[namespace] = {}
|
||||
@@ -538,7 +530,7 @@ def finalize_share_data():
|
||||
"""
|
||||
global \
|
||||
_manager, \
|
||||
is_multiprocess, \
|
||||
_is_multiprocess, \
|
||||
_storage_lock, \
|
||||
_internal_lock, \
|
||||
_pipeline_status_lock, \
|
||||
@@ -558,11 +550,11 @@ def finalize_share_data():
|
||||
return
|
||||
|
||||
direct_log(
|
||||
f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
|
||||
f"Process {os.getpid()} finalizing storage data (multiprocess={_is_multiprocess})"
|
||||
)
|
||||
|
||||
# In multi-process mode, shut down the Manager
|
||||
if is_multiprocess and _manager is not None:
|
||||
if _is_multiprocess and _manager is not None:
|
||||
try:
|
||||
# Clear shared resources before shutting down Manager
|
||||
if _shared_dicts is not None:
|
||||
@@ -604,7 +596,7 @@ def finalize_share_data():
|
||||
# Reset global variables
|
||||
_manager = None
|
||||
_initialized = None
|
||||
is_multiprocess = None
|
||||
_is_multiprocess = None
|
||||
_shared_dicts = None
|
||||
_init_flags = None
|
||||
_storage_lock = None
|
||||
|
Reference in New Issue
Block a user