Merge branch 'main' into main

Alex Z
2025-04-05 15:27:59 -07:00
committed by GitHub
77 changed files with 5920 additions and 5192 deletions

View File

@@ -11,7 +11,6 @@
- [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
- [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
- [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
- [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
- [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
@@ -410,6 +409,54 @@ if __name__ == "__main__":
</details>

### Token Usage Tracking

<details>
<summary> <b>Overview and Usage</b> </summary>

LightRAG provides a TokenTracker tool to monitor and manage token consumption by large language models. This feature is particularly useful for controlling API costs and optimizing performance.

#### Usage

```python
from lightrag.utils import TokenTracker

# Create a TokenTracker instance
token_tracker = TokenTracker()

# Method 1: Using the context manager (recommended)
# Suitable for scenarios requiring automatic token usage tracking
with token_tracker:
    result1 = await llm_model_func("your question 1")
    result2 = await llm_model_func("your question 2")

# Method 2: Manually adding token usage records
# Suitable for scenarios requiring more granular control over token statistics
token_tracker.reset()

rag.insert("your document text")

rag.query("your question 1", param=QueryParam(mode="naive"))
rag.query("your question 2", param=QueryParam(mode="mix"))

# Display total token usage (covering both insert and query operations)
print("Token usage:", token_tracker.get_usage())
```

#### Usage Tips
- Use the context manager for long sessions or batch operations to automatically track all token consumption
- For scenarios requiring segmented statistics, use manual mode and call reset() when appropriate
- Checking token usage regularly helps detect abnormal consumption early
- Use this feature actively during development and testing to optimize production costs

#### Practical Examples
You can refer to the following examples to implement token tracking:
- `examples/lightrag_gemini_track_token_demo.py`: token tracking example using the Google Gemini model
- `examples/lightrag_siliconcloud_track_token_demo.py`: token tracking example using the SiliconCloud model

These examples demonstrate how to use the TokenTracker feature effectively with different models and scenarios.
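For the segmented-statistics tip above, here is a minimal sketch; it assumes an initialized `rag` instance and a tracker already wired into your model function, as in the demo scripts listed:

```python
from lightrag import QueryParam
from lightrag.utils import TokenTracker

token_tracker = TokenTracker()

# Phase 1: measure the token cost of insertion only
token_tracker.reset()
rag.insert("your document text")
insert_usage = token_tracker.get_usage()

# Phase 2: reset, then measure the token cost of queries only
token_tracker.reset()
rag.query("your question 1", param=QueryParam(mode="naive"))
query_usage = token_tracker.get_usage()

print("Insert phase:", insert_usage)
print("Query phase:", query_usage)
```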
</details>
### Conversation History

LightRAG now supports multi-turn dialogue through the conversation history feature. Here is how to use it:
@@ -1037,9 +1084,10 @@ rag.clear_cache(modes=["local"])
| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `PGKVStorage`, `RedisKVStorage`, `MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `PGVectorStorage`, `MilvusVectorDBStorage`, `ChromaVectorDBStorage`, `FaissVectorDBStorage`, `MongoVectorDBStorage`, `QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `PGGraphStorage`, `AGEStorage` | `NetworkXStorage` |
| **doc_status_storage** | `str` | Storage type for document processing status. Supported types: `JsonDocStatusStorage`, `PGDocStatusStorage`, `MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |

View File

@@ -41,7 +41,6 @@
- [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
- [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
- [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
- [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
- [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
@@ -443,6 +442,55 @@ if __name__ == "__main__":
</details>
### Token Usage Tracking
<details>
<summary> <b>Overview and Usage</b> </summary>
LightRAG provides a TokenTracker tool to monitor and manage token consumption by large language models. This feature is particularly useful for controlling API costs and optimizing performance.
#### Usage
```python
from lightrag.utils import TokenTracker
# Create TokenTracker instance
token_tracker = TokenTracker()
# Method 1: Using context manager (Recommended)
# Suitable for scenarios requiring automatic token usage tracking
with token_tracker:
result1 = await llm_model_func("your question 1")
result2 = await llm_model_func("your question 2")
# Method 2: Manually adding token usage records
# Suitable for scenarios requiring more granular control over token statistics
token_tracker.reset()
rag.insert("your document text")
rag.query("your question 1", param=QueryParam(mode="naive"))
rag.query("your question 2", param=QueryParam(mode="mix"))
# Display total token usage (including insert and query operations)
print("Token usage:", token_tracker.get_usage())
```
#### Usage Tips
- Use context managers for long sessions or batch operations to automatically track all token consumption
- For scenarios requiring segmented statistics, use manual mode and call reset() when appropriate
- Regular checking of token usage helps detect abnormal consumption early
- Actively use this feature during development and testing to optimize production costs
#### Practical Examples
You can refer to these examples for implementing token tracking:
- `examples/lightrag_gemini_track_token_demo.py`: Token tracking example using Google Gemini model
- `examples/lightrag_siliconcloud_track_token_demo.py`: Token tracking example using SiliconCloud model
These examples demonstrate how to effectively use the TokenTracker feature with different models and scenarios.
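For intuition, here is a minimal sketch of what a tracker with this interface could look like. This is an illustration only, not LightRAG's actual TokenTracker implementation; the `add_usage` method and the usage-dict fields are assumptions:

```python
class SimpleTokenTracker:
    """Illustrative token tracker with the same outward interface."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Zero out the accumulated counts
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def add_usage(self, prompt_tokens: int, completion_tokens: int):
        # Called after each LLM request with that request's token counts
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

    def get_usage(self) -> dict:
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.prompt_tokens + self.completion_tokens,
        }

    def __enter__(self):
        self.reset()
        return self

    def __exit__(self, exc_type, exc, tb):
        # Keep the accumulated totals readable after the block exits
        return False
```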
</details>
### Conversation History Support
@@ -607,7 +655,7 @@ The `apipeline_enqueue_documents` and `apipeline_process_enqueue_documents` func
This is useful for scenarios where you want to process documents in the background while still allowing the main thread to continue executing.

And using a routine to process new documents:

```python
rag = LightRAG(..)
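
# A hedged sketch of the routine described above; it assumes the two pipeline
# methods named in this section and an asyncio event loop. The function name
# and its argument are illustrative.
async def process_new_documents(incoming_docs: list[str]):
    # Enqueue whatever has arrived, then drain the processing queue
    await rag.apipeline_enqueue_documents(incoming_docs)
    await rag.apipeline_process_enqueue_documents()
```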
@@ -1096,9 +1144,10 @@ Valid modes are:
| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `PGKVStorage`, `RedisKVStorage`, `MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `PGVectorStorage`, `MilvusVectorDBStorage`, `ChromaVectorDBStorage`, `FaissVectorDBStorage`, `MongoVectorDBStorage`, `QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `PGGraphStorage`, `AGEStorage` | `NetworkXStorage` |
| **doc_status_storage** | `str` | Storage type for document processing status. Supported types: `JsonDocStatusStorage`, `PGDocStatusStorage`, `MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
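For illustration, a minimal sketch of wiring several of these parameters into the constructor. The storage names come from the table above; `llm_model_func` and `embedding_func` stand in for real callables you must supply, as in the initialization examples elsewhere in this README:

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",
    kv_storage="RedisKVStorage",              # documents and text chunks
    vector_storage="FaissVectorDBStorage",    # embedding vectors
    graph_storage="Neo4JStorage",             # graph edges and nodes
    doc_status_storage="PGDocStatusStorage",  # document processing status
    chunk_token_size=1200,
    chunk_overlap_token_size=100,
    llm_model_func=llm_model_func,    # placeholder: your LLM callable
    embedding_func=embedding_func,    # placeholder: your embedding callable
)
```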

View File

@@ -13,23 +13,6 @@ uri=redis://localhost:6379/1
[qdrant]
uri = http://localhost:16333

[oracle]
dsn = localhost:1521/XEPDB1
user = your_username
password = your_password
config_dir = /path/to/oracle/config
wallet_location = /path/to/wallet  # optional
wallet_password = your_wallet_password  # optional
workspace = default  # optional, defaults to "default"

[tidb]
host = localhost
port = 4000
user = your_username
password = your_password
database = your_database
workspace = default  # optional, defaults to "default"

[postgres]
host = localhost
port = 5432
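These sections are plain INI. A minimal sketch of reading one of them with Python's standard configparser (the `config.ini` filename matches what the LightRAG server loads; the value handling is illustrative):

```python
import configparser

config = configparser.ConfigParser()
config.read("config.ini")

# Values come back as strings; use the typed getters where needed
pg_host = config.get("postgres", "host", fallback="localhost")
pg_port = config.getint("postgres", "port", fallback=5432)
print(f"PostgreSQL at {pg_host}:{pg_port}")
```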

View File

@@ -4,11 +4,9 @@
# HOST=0.0.0.0
# PORT=9621
# WORKERS=2
### Separating data from different LightRAG instances
# NAMESPACE_PREFIX=lightrag
### Max nodes returned from graph retrieval
# MAX_GRAPH_NODES=1000
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080

WEBUI_TITLE='Graph RAG Engine'
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"

### Optional SSL Configuration
# SSL=true
@@ -22,6 +20,9 @@
### Ollama Emulating Model Tag
# OLLAMA_EMULATING_MODEL_TAG=latest

### Max nodes returned from graph retrieval
# MAX_GRAPH_NODES=1000

### Logging level
# LOG_LEVEL=INFO
# VERBOSE=False
@@ -110,24 +111,14 @@ LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage

### TiDB Configuration (Deprecated)
# TIDB_HOST=localhost
# TIDB_PORT=4000
# TIDB_USER=your_username
# TIDB_PASSWORD='your_password'
# TIDB_DATABASE=your_database
### Separating all data from different LightRAG instances (deprecating)
# TIDB_WORKSPACE=default
### PostgreSQL Configuration
POSTGRES_HOST=localhost
@@ -135,8 +126,8 @@ POSTGRES_PORT=5432
POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database
### Separating all data from different LightRAG instances (deprecating)
# POSTGRES_WORKSPACE=default

### Independent AGM Configuration (not for AGM embedded in PostgreSQL)
AGE_POSTGRES_DB=
@@ -145,8 +136,8 @@ AGE_POSTGRES_PASSWORD=
AGE_POSTGRES_HOST=
# AGE_POSTGRES_PORT=8529
### Separating all data from different LightRAG instances (deprecating, use NAMESPACE_PREFIX in future)
# AGE Graph Name (applies to PostgreSQL and independent AGM)
### AGE_GRAPH_NAME is deprecated
# AGE_GRAPH_NAME=lightrag

### Neo4j Configuration
@@ -157,7 +148,7 @@ NEO4J_PASSWORD='your_password'
### MongoDB Configuration
MONGO_URI=mongodb://root:root@localhost:27017/
MONGO_DATABASE=LightRAG
### Separating all data from different LightRAG instances (deprecating)
# MONGODB_GRAPH=false

### Milvus Configuration
@@ -177,7 +168,9 @@ REDIS_URI=redis://localhost:6379
### For JWT Auth
# AUTH_ACCOUNTS='admin:admin123,user1:pass456'
# TOKEN_SECRET=Your-Key-For-LightRAG-API-Server
# TOKEN_EXPIRE_HOURS=48
# GUEST_TOKEN_EXPIRE_HOURS=24
# JWT_ALGORITHM=HS256

### API-Key to access LightRAG Server API
# LIGHTRAG_API_KEY=your-secure-api-key-here
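A hedged sketch of exercising the JWT settings above against a running server. The `/login` endpoint and its OAuth2 password form appear in the server code later in this diff, but the exact shape of the JSON response (the `access_token` field is assumed here) should be checked against your version:

```python
import requests

BASE = "http://localhost:9621"

# Log in with one of the AUTH_ACCOUNTS pairs, e.g. 'admin:admin123'
resp = requests.post(
    f"{BASE}/login",
    data={"username": "admin", "password": "admin123"},  # OAuth2 password form
)
resp.raise_for_status()
token = resp.json()["access_token"]  # assumed field name

# Use the bearer token on protected endpoints such as /health
health = requests.get(f"{BASE}/health", headers={"Authorization": f"Bearer {token}"})
print(health.json())
```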

View File

@@ -1,188 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from contextlib import asynccontextmanager
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
from typing import Optional
import asyncio
import nest_asyncio
import aiofiles
from lightrag.kg.shared_storage import initialize_pipeline_status
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
DEFAULT_RAG_DIR = "index_default"
DEFAULT_INPUT_FILE = "book.txt"
INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}")
print(f"INPUT_FILE: {INPUT_FILE}")
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def init():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete,
llm_model_name="gemma2:9b",
llm_model_max_async=4,
llm_model_max_token_size=8192,
llm_model_kwargs={
"host": "http://localhost:11434",
"options": {"num_ctx": 8192},
},
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embed(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),
)
# Add initialization code
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
rag = await init()
print("done!")
yield
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)
# Data models
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
only_need_context: bool = False
class InsertRequest(BaseModel):
text: str
class Response(BaseModel):
status: str
data: Optional[str] = None
message: Optional[str] = None
# API routes
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
lambda: rag.query(
request.query,
param=QueryParam(
mode=request.mode, only_need_context=request.only_need_context
),
),
)
return Response(status="success", data=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# insert by text
@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(request.text))
return Response(status="success", message="Text inserted successfully")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# insert by file in payload
@app.post("/insert_file", response_model=Response)
async def insert_file(file: UploadFile = File(...)):
try:
file_content = await file.read()
# Read file content
try:
content = file_content.decode("utf-8")
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
content = file_content.decode("gbk")
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
return Response(
status="success",
message=f"File content from {file.filename} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# insert by local default file
@app.post("/insert_default_file", response_model=Response)
@app.get("/insert_default_file", response_model=Response)
async def insert_default_file():
try:
# Read file content from book.txt
async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file:
content = await file.read()
print(f"read input file {INPUT_FILE} successfully")
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
return Response(
status="success",
message=f"File content from {INPUT_FILE} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8020)
# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py
# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -1,204 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from contextlib import asynccontextmanager
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from typing import Optional
import asyncio
import nest_asyncio
from lightrag.kg.shared_storage import initialize_pipeline_status
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
DEFAULT_RAG_DIR = "index_default"
app = FastAPI(title="LightRAG API", description="API for RAG operations")
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
print(f"BASE_URL: {BASE_URL}")
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
print(f"API_KEY: {API_KEY}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# LLM model function
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
model=LLM_MODEL,
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=BASE_URL,
api_key=API_KEY,
**kwargs,
)
# Embedding function
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts=texts,
model=EMBEDDING_MODEL,
base_url=BASE_URL,
api_key=API_KEY,
)
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"{embedding_dim=}")
return embedding_dim
# Initialize RAG instance
async def init():
embedding_dimension = await get_embedding_dim()
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
rag = await init()
print("done!")
yield
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)
# Data models
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
only_need_context: bool = False
class InsertRequest(BaseModel):
text: str
class Response(BaseModel):
status: str
data: Optional[str] = None
message: Optional[str] = None
# API routes
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
lambda: rag.query(
request.query,
param=QueryParam(
mode=request.mode, only_need_context=request.only_need_context
),
),
)
return Response(status="success", data=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(request.text))
return Response(status="success", message="Text inserted successfully")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/insert_file", response_model=Response)
async def insert_file(file: UploadFile = File(...)):
try:
file_content = await file.read()
# Read file content
try:
content = file_content.decode("utf-8")
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
content = file_content.decode("gbk")
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
return Response(
status="success",
message=f"File content from {file.filename} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8020)
# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py
# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -1,267 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Query
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional, Any
import sys
import os
from pathlib import Path
import asyncio
import nest_asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
DEFAULT_RAG_DIR = "index_default"
# We use OpenAI compatible API to call LLM on Oracle Cloud
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
os.environ["ORACLE_USER"] = ""
os.environ["ORACLE_PASSWORD"] = ""
os.environ["ORACLE_DSN"] = ""
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
os.environ["ORACLE_WORKSPACE"] = "company"
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
LLM_MODEL,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=APIKEY,
base_url=BASE_URL,
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
model=EMBEDDING_MODEL,
api_key=APIKEY,
base_url=BASE_URL,
)
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
return embedding_dim
async def init():
# Detect embedding dimension
embedding_dimension = await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Create Oracle DB connection
# The `config` parameter is the connection configuration of Oracle DB
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
rag = LightRAG(
enable_llm_cache=False,
working_dir=WORKING_DIR,
chunk_token_size=512,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=512,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
# Extract and Insert into LightRAG storage
# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# await rag.ainsert(f.read())
# # Perform search in different modes
# modes = ["naive", "local", "global", "hybrid"]
# for mode in modes:
# print("="*20, mode, "="*20)
# print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode)))
# print("-"*100, "\n")
# Data models
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
only_need_context: bool = False
only_need_prompt: bool = False
class DataRequest(BaseModel):
limit: int = 100
class InsertRequest(BaseModel):
text: str
class Response(BaseModel):
status: str
data: Optional[Any] = None
message: Optional[str] = None
# API routes
rag = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global rag
rag = await init()
print("done!")
yield
app = FastAPI(
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
# try:
# loop = asyncio.get_event_loop()
if request.mode == "naive":
top_k = 3
else:
top_k = 60
result = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
only_need_context=request.only_need_context,
only_need_prompt=request.only_need_prompt,
top_k=top_k,
),
)
return Response(status="success", data=result)
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
@app.get("/data", response_model=Response)
async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
if type == "nodes":
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
elif type == "edges":
result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
elif type == "statistics":
result = await rag.chunk_entity_relation_graph.get_statistics()
return Response(status="success", data=result)
@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(request.text))
return Response(status="success", message="Text inserted successfully")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/insert_file", response_model=Response)
async def insert_file(file: UploadFile = File(...)):
try:
file_content = await file.read()
# Read file content
try:
content = file_content.decode("utf-8")
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
content = file_content.decode("gbk")
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
return Response(
status="success",
message=f"File content from {file.filename} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8020)
# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py
# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -1,3 +1,7 @@
##############################################
# Gremlin storage implementation is deprecated
##############################################
import asyncio
import inspect
import os

View File

@@ -1,141 +0,0 @@
import sys
import os
from pathlib import Path
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))
WORKING_DIR = "./dickens"
# We use OpenAI compatible API to call LLM on Oracle Cloud
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"
CHATMODEL = "cohere.command-r-plus"
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
CHUNK_TOKEN_SIZE = 1024
MAX_TOKENS = 4000
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
os.environ["ORACLE_USER"] = "username"
os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
os.environ["ORACLE_WORKSPACE"] = "company"
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
CHATMODEL,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=APIKEY,
base_url=BASE_URL,
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
model=EMBEDMODEL,
api_key=APIKEY,
base_url=BASE_URL,
)
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
return embedding_dim
async def initialize_rag():
# Detect embedding dimension
embedding_dimension = await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
rag = LightRAG(
# log_level="DEBUG",
working_dir=WORKING_DIR,
entity_extract_max_gleaning=1,
enable_llm_cache=True,
enable_llm_cache_for_entity_extract=True,
embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
chunk_token_size=CHUNK_TOKEN_SIZE,
llm_model_max_token_size=MAX_TOKENS,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=500,
func=embedding_func,
),
graph_storage="OracleGraphStorage",
kv_storage="OracleKVStorage",
vector_storage="OracleVectorDBStorage",
addon_params={
"example_number": 1,
"language": "Simplfied Chinese",
"entity_types": ["organization", "person", "geo", "event"],
"insert_batch_size": 2,
},
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
async def main():
try:
# Initialize RAG instance
rag = await initialize_rag()
# Extract and Insert into LightRAG storage
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
all_text = f.read()
texts = [x for x in all_text.split("\n") if x]
# New mode use pipeline
await rag.apipeline_enqueue_documents(texts)
await rag.apipeline_process_enqueue_documents()
# Old method use ainsert
# await rag.ainsert(texts)
# Perform search in different modes
modes = ["naive", "local", "global", "hybrid"]
for mode in modes:
print("=" * 20, mode, "=" * 20)
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode=mode),
)
)
print("-" * 100, "\n")
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,3 +1,7 @@
###########################################
# TiDB storage implementation is deprecated
###########################################
import asyncio
import os

View File

@@ -291,11 +291,9 @@ LightRAG uses 4 types of storage for different purposes:

```
JsonKVStorage    JsonFile (default)
PGKVStorage      Postgres
RedisKVStorage   Redis
MongoKVStorage   MongoDB
```

* GRAPH_STORAGE supported implementation names

@@ -303,25 +301,19 @@ OracleKVStorage Oracle

```
NetworkXStorage  NetworkX (default)
Neo4JStorage     Neo4J
PGGraphStorage   Postgres
AGEStorage       AGE
```

* VECTOR_STORAGE supported implementation names

```
NanoVectorDBStorage    NanoVector (default)
PGVectorStorage        Postgres
MilvusVectorDBStorage  Milvus
ChromaVectorDBStorage  Chroma
FaissVectorDBStorage   Faiss
QdrantVectorDBStorage  Qdrant
MongoVectorDBStorage   MongoDB
```

View File

@@ -302,11 +302,9 @@ Each storage type has several implementations:

```
JsonKVStorage    JsonFile (default)
PGKVStorage      Postgres
RedisKVStorage   Redis
MongoKVStorage   MongoDB
```

* GRAPH_STORAGE supported implementation names

@@ -314,25 +312,19 @@ OracleKVStorage Oracle

```
NetworkXStorage  NetworkX (default)
Neo4JStorage     Neo4J
PGGraphStorage   Postgres
AGEStorage       AGE
```

* VECTOR_STORAGE supported implementation names

```
NanoVectorDBStorage    NanoVector (default)
PGVectorStorage        Postgres
MilvusVectorDBStorage  Milvus
ChromaVectorDBStorage  Chroma
FaissVectorDBStorage   Faiss
QdrantVectorDBStorage  Qdrant
MongoVectorDBStorage   MongoDB
```
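These implementation names are exactly what the `LIGHTRAG_*_STORAGE` environment variables accept. A minimal sketch of the selection logic, mirroring the `get_env_value` defaults in `lightrag/api/config.py` later in this diff:

```python
import os

# Defaults mirror DefaultRAGStorageConfig in lightrag/api/config.py
kv_storage = os.getenv("LIGHTRAG_KV_STORAGE", "JsonKVStorage")
vector_storage = os.getenv("LIGHTRAG_VECTOR_STORAGE", "NanoVectorDBStorage")
graph_storage = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
doc_status_storage = os.getenv("LIGHTRAG_DOC_STATUS_STORAGE", "JsonDocStatusStorage")

print(kv_storage, vector_storage, graph_storage, doc_status_storage)
```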

View File

@@ -1 +1 @@
__api_version__ = "0136"

View File

@@ -1,9 +1,11 @@
from datetime import datetime, timedelta

import jwt
from fastapi import HTTPException, status
from pydantic import BaseModel
from dotenv import load_dotenv

from .config import global_args
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance

@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):

class AuthHandler:
    def __init__(self):
        self.secret = global_args.token_secret
        self.algorithm = global_args.jwt_algorithm
        self.expire_hours = global_args.token_expire_hours
        self.guest_expire_hours = global_args.guest_token_expire_hours
        self.accounts = {}
        auth_accounts = global_args.auth_accounts
        if auth_accounts:
            for account in auth_accounts.split(","):
                username, password = account.split(":", 1)
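For context, a minimal sketch of issuing and validating a token with PyJWT using settings like these; the payload claims here are illustrative, not AuthHandler's actual ones:

```python
from datetime import datetime, timedelta, timezone

import jwt

secret = "lightrag-jwt-default-secret"  # global_args.token_secret in the server
algorithm = "HS256"                     # global_args.jwt_algorithm

# Issue a token that expires after token_expire_hours
payload = {
    "sub": "admin",  # illustrative claim
    "exp": datetime.now(timezone.utc) + timedelta(hours=48),
}
token = jwt.encode(payload, secret, algorithm=algorithm)

# Validation: jwt.decode raises ExpiredSignatureError / InvalidTokenError on failure
decoded = jwt.decode(token, secret, algorithms=[algorithm])
print(decoded["sub"])
```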

lightrag/api/config.py (new file, 335 lines)
View File

@@ -0,0 +1,335 @@
"""
Configs for the LightRAG API.
"""
import os
import argparse
import logging
from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
ollama_server_infos = OllamaServerInfos()
class DefaultRAGStorageConfig:
KV_STORAGE = "JsonKVStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
GRAPH_STORAGE = "NetworkXStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
}
return default_hosts.get(
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
) # fallback to ollama if unknown
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
"""
Get value from environment variable with type conversion
Args:
env_key (str): Environment variable key
default (any): Default value if env variable is not set
value_type (type): Type to convert the value to
Returns:
any: Converted value from environment or default
"""
value = os.getenv(env_key)
if value is None:
return default
if value_type is bool:
return value.lower() in ("true", "1", "yes", "t", "on")
try:
return value_type(value)
except ValueError:
return default
def parse_args() -> argparse.Namespace:
    """
    Parse command line arguments with environment variable fallback

    Returns:
        argparse.Namespace: Parsed arguments
    """
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Server configuration
parser.add_argument(
"--host",
default=get_env_value("HOST", "0.0.0.0"),
help="Server host (default: from env or 0.0.0.0)",
)
parser.add_argument(
"--port",
type=int,
default=get_env_value("PORT", 9621, int),
help="Server port (default: from env or 9621)",
)
# Directory configuration
parser.add_argument(
"--working-dir",
default=get_env_value("WORKING_DIR", "./rag_storage"),
help="Working directory for RAG storage (default: from env or ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default=get_env_value("INPUT_DIR", "./inputs"),
help="Directory containing input documents (default: from env or ./inputs)",
)
    def timeout_type(value):
        if value is None:
            return 150
        if value == "None":
            return None
        return int(value)
parser.add_argument(
"--timeout",
default=get_env_value("TIMEOUT", None, timeout_type),
type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
)
# RAG configuration
parser.add_argument(
"--max-async",
type=int,
default=get_env_value("MAX_ASYNC", 4, int),
help="Maximum async operations (default: from env or 4)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=get_env_value("MAX_TOKENS", 32768, int),
help="Maximum token size (default: from env or 32768)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default=get_env_value("LOG_LEVEL", "INFO"),
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: from env or INFO)",
)
parser.add_argument(
"--verbose",
action="store_true",
default=get_env_value("VERBOSE", False, bool),
help="Enable verbose debug output(only valid for DEBUG log-level)",
)
parser.add_argument(
"--key",
type=str,
default=get_env_value("LIGHTRAG_API_KEY", None),
help="API key for authentication. This protects lightrag server against unauthorized access",
)
# Optional https parameters
parser.add_argument(
"--ssl",
action="store_true",
default=get_env_value("SSL", False, bool),
help="Enable HTTPS (default: from env or False)",
)
parser.add_argument(
"--ssl-certfile",
default=get_env_value("SSL_CERTFILE", None),
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=get_env_value("SSL_KEYFILE", None),
help="Path to SSL private key file (required if --ssl is enabled)",
)
parser.add_argument(
"--history-turns",
type=int,
default=get_env_value("HISTORY_TURNS", 3, int),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Search parameters
parser.add_argument(
"--top-k",
type=int,
default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
help="Cosine similarity threshold (default: from env or 0.4)",
)
# Ollama model name
parser.add_argument(
"--simulated-model-name",
type=str,
default=get_env_value(
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Namespace
parser.add_argument(
"--namespace-prefix",
type=str,
default=get_env_value("NAMESPACE_PREFIX", ""),
help="Prefix of the namespace",
)
parser.add_argument(
"--auto-scan-at-startup",
action="store_true",
default=False,
help="Enable automatic scanning when the program starts",
)
# Server workers configuration
parser.add_argument(
"--workers",
type=int,
default=get_env_value("WORKERS", 1, int),
help="Number of worker processes (default: from env or 1)",
)
# LLM and embedding bindings
parser.add_argument(
"--llm-binding",
type=str,
default=get_env_value("LLM_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
help="LLM binding type (default: from env or ollama)",
)
parser.add_argument(
"--embedding-binding",
type=str,
default=get_env_value("EMBEDDING_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "azure_openai"],
help="Embedding binding type (default: from env or ollama)",
)
args = parser.parse_args()
# convert relative path to absolute path
args.working_dir = os.path.abspath(args.working_dir)
args.input_dir = os.path.abspath(args.input_dir)
# Inject storage configuration from environment variables
args.kv_storage = get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
)
args.doc_status_storage = get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
)
args.graph_storage = get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
)
args.vector_storage = get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
)
# Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
# Handle openai-ollama special case
if args.llm_binding == "openai-ollama":
args.llm_binding = "openai"
args.embedding_binding = "ollama"
args.llm_binding_host = get_env_value(
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
)
args.embedding_binding_host = get_env_value(
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
)
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
# Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
# Inject chunk configuration
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
# Inject LLM cache configuration
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
)
# Inject LLM temperature configuration
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
# Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
# Add environment variables that were previously read directly
args.cors_origins = get_env_value("CORS_ORIGINS", "*")
args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
# For JWT Auth
args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
return args
def update_uvicorn_mode_config():
# If in uvicorn mode and workers > 1, force it to 1 and log warning
if global_args.workers > 1:
original_workers = global_args.workers
global_args.workers = 1
# Log warning directly here
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
)
global_args = parse_args()

View File

@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
    get_combined_auth_dependency,
    display_splash_screen,
    check_env_file,
)
from .config import (
    global_args,
    update_uvicorn_mode_config,
    get_default_host,
)
import sys
from lightrag import LightRAG, __version__ as core_version
from lightrag.api import __api_version__
@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)

webui_title = os.getenv("WEBUI_TITLE")
webui_description = os.getenv("WEBUI_DESCRIPTION")

# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
@@ -164,10 +171,10 @@ def create_app(args):
    app = FastAPI(**app_kwargs)

    def get_cors_origins():
        """Get allowed origins from global_args
        Returns a list of allowed origins, defaults to ["*"] if not set
        """
        origins_str = global_args.cors_origins
        if origins_str == "*":
            return ["*"]
        return [origin.strip() for origin in origins_str.split(",")]
@@ -315,9 +322,10 @@ def create_app(args):
"similarity_threshold": 0.95, "similarity_threshold": 0.95,
"use_llm_check": False, "use_llm_check": False,
}, },
namespace_prefix=args.namespace_prefix, # namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
addon_params={"language": args.summary_language},
) )
else: # azure_openai else: # azure_openai
rag = LightRAG( rag = LightRAG(
@@ -345,9 +353,10 @@ def create_app(args):
"similarity_threshold": 0.95, "similarity_threshold": 0.95,
"use_llm_check": False, "use_llm_check": False,
}, },
namespace_prefix=args.namespace_prefix, # namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
addon_params={"language": args.summary_language},
) )
# Add routes # Add routes
@@ -381,6 +390,8 @@ def create_app(args):
"message": "Authentication is disabled. Using guest access.", "message": "Authentication is disabled. Using guest access.",
"core_version": core_version, "core_version": core_version,
"api_version": __api_version__, "api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
} }
return { return {
@@ -388,6 +399,8 @@ def create_app(args):
"auth_mode": "enabled", "auth_mode": "enabled",
"core_version": core_version, "core_version": core_version,
"api_version": __api_version__, "api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
} }
@app.post("/login") @app.post("/login")
@@ -404,6 +417,8 @@ def create_app(args):
"message": "Authentication is disabled. Using guest access.", "message": "Authentication is disabled. Using guest access.",
"core_version": core_version, "core_version": core_version,
"api_version": __api_version__, "api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
} }
username = form_data.username username = form_data.username
if auth_handler.accounts.get(username) != form_data.password: if auth_handler.accounts.get(username) != form_data.password:
@@ -421,6 +436,8 @@ def create_app(args):
"auth_mode": "enabled", "auth_mode": "enabled",
"core_version": core_version, "core_version": core_version,
"api_version": __api_version__, "api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
} }
@app.get("/health", dependencies=[Depends(combined_auth)]) @app.get("/health", dependencies=[Depends(combined_auth)])
@@ -454,10 +471,12 @@ def create_app(args):
"vector_storage": args.vector_storage, "vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
}, },
"core_version": core_version,
"api_version": __api_version__,
"auth_mode": auth_mode, "auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False), "pipeline_busy": pipeline_status.get("busy", False),
"core_version": core_version,
"api_version": __api_version__,
"webui_title": webui_title,
"webui_description": webui_description,
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting health status: {str(e)}") logger.error(f"Error getting health status: {str(e)}")
@@ -490,7 +509,7 @@ def create_app(args):
def get_application(args=None): def get_application(args=None):
"""Factory function for creating the FastAPI application""" """Factory function for creating the FastAPI application"""
if args is None: if args is None:
args = parse_args() args = global_args
return create_app(args) return create_app(args)
@@ -611,30 +630,31 @@ def main():
# Configure logging before parsing args # Configure logging before parsing args
configure_logging() configure_logging()
update_uvicorn_mode_config()
args = parse_args(is_uvicorn_mode=True) display_splash_screen(global_args)
display_splash_screen(args)
# Create application instance directly instead of using factory function # Create application instance directly instead of using factory function
app = create_app(args) app = create_app(global_args)
# Start Uvicorn in single process mode # Start Uvicorn in single process mode
uvicorn_config = { uvicorn_config = {
"app": app, # Pass application instance directly instead of string path "app": app, # Pass application instance directly instead of string path
"host": args.host, "host": global_args.host,
"port": args.port, "port": global_args.port,
"log_config": None, # Disable default config "log_config": None, # Disable default config
} }
if args.ssl: if global_args.ssl:
uvicorn_config.update( uvicorn_config.update(
{ {
"ssl_certfile": args.ssl_certfile, "ssl_certfile": global_args.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile, "ssl_keyfile": global_args.ssl_keyfile,
} }
) )
print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}") print(
f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
)
uvicorn.run(**uvicorn_config) uvicorn.run(**uvicorn_config)
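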
View File
@@ -10,16 +10,14 @@ import traceback
import pipmaster as pm import pipmaster as pm
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any, Literal
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.api.utils_api import ( from lightrag.api.utils_api import get_combined_auth_dependency
get_combined_auth_dependency, from ..config import global_args
global_args,
)
router = APIRouter( router = APIRouter(
prefix="/documents", prefix="/documents",
@@ -30,7 +28,37 @@ router = APIRouter(
temp_prefix = "__tmp__" temp_prefix = "__tmp__"
class ScanResponse(BaseModel):
"""Response model for document scanning operation
Attributes:
status: Status of the scanning operation
message: Optional message with additional details
"""
status: Literal["scanning_started"] = Field(
description="Status of the scanning operation"
)
message: Optional[str] = Field(
default=None, description="Additional details about the scanning operation"
)
class Config:
json_schema_extra = {
"example": {
"status": "scanning_started",
"message": "Scanning process has been initiated in the background",
}
}
class InsertTextRequest(BaseModel): class InsertTextRequest(BaseModel):
"""Request model for inserting a single text document
Attributes:
text: The text content to be inserted into the RAG system
"""
text: str = Field( text: str = Field(
min_length=1, min_length=1,
description="The text to insert", description="The text to insert",
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
def strip_after(cls, text: str) -> str: def strip_after(cls, text: str) -> str:
return text.strip() return text.strip()
class Config:
json_schema_extra = {
"example": {
"text": "This is a sample text to be inserted into the RAG system."
}
}
class InsertTextsRequest(BaseModel): class InsertTextsRequest(BaseModel):
"""Request model for inserting multiple text documents
Attributes:
texts: List of text contents to be inserted into the RAG system
"""
texts: list[str] = Field( texts: list[str] = Field(
min_length=1, min_length=1,
description="The texts to insert", description="The texts to insert",
@@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel):
def strip_after(cls, texts: list[str]) -> list[str]: def strip_after(cls, texts: list[str]) -> list[str]:
return [text.strip() for text in texts] return [text.strip() for text in texts]
class Config:
json_schema_extra = {
"example": {
"texts": [
"This is the first text to be inserted.",
"This is the second text to be inserted.",
]
}
}
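A hedged usage sketch for the insertion models above, assuming the router's `/documents` prefix and a `/texts` route matching the `InsertTextsRequest` schema:

```python
import requests

payload = {"texts": ["This is the first text.", "This is the second text."]}
resp = requests.post(
    "http://localhost:9621/documents/texts",  # assumed route under the /documents prefix
    json=payload,
    timeout=60,
)
print(resp.json())  # e.g. {"status": "success", "message": "..."}
```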
class InsertResponse(BaseModel): class InsertResponse(BaseModel):
status: str = Field(description="Status of the operation") """Response model for document insertion operations
Attributes:
status: Status of the operation (success, duplicated, partial_success, failure)
message: Detailed message describing the operation result
"""
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
description="Status of the operation"
)
message: str = Field(description="Message describing the operation result") message: str = Field(description="Message describing the operation result")
class Config:
json_schema_extra = {
"example": {
"status": "success",
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
}
}
class ClearDocumentsResponse(BaseModel):
"""Response model for document clearing operation
Attributes:
status: Status of the clear operation
message: Detailed message describing the operation result
"""
status: Literal["success", "partial_success", "busy", "fail"] = Field(
description="Status of the clear operation"
)
message: str = Field(description="Message describing the operation result")
class Config:
json_schema_extra = {
"example": {
"status": "success",
"message": "All documents cleared successfully. Deleted 15 files.",
}
}
class ClearCacheRequest(BaseModel):
"""Request model for clearing cache
Attributes:
modes: Optional list of cache modes to clear
"""
modes: Optional[
List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
] = Field(
default=None,
description="Modes of cache to clear. If None, clears all cache.",
)
class Config:
json_schema_extra = {"example": {"modes": ["default", "naive"]}}
class ClearCacheResponse(BaseModel):
"""Response model for cache clearing operation
Attributes:
status: Status of the clear operation
message: Detailed message describing the operation result
"""
status: Literal["success", "fail"] = Field(
description="Status of the clear operation"
)
message: str = Field(description="Message describing the operation result")
class Config:
json_schema_extra = {
"example": {
"status": "success",
"message": "Successfully cleared cache for modes: ['default', 'naive']",
}
}
"""Response model for document status
Attributes:
id: Document identifier
content_summary: Summary of document content
content_length: Length of document content
status: Current processing status
created_at: Creation timestamp (ISO format string)
updated_at: Last update timestamp (ISO format string)
chunks_count: Number of chunks (optional)
error: Error message if any (optional)
metadata: Additional metadata (optional)
file_path: Path to the document file
"""
class DocStatusResponse(BaseModel): class DocStatusResponse(BaseModel):
@staticmethod @staticmethod
@@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel):
return dt return dt
return dt.isoformat() return dt.isoformat()
"""Response model for document status id: str = Field(description="Document identifier")
content_summary: str = Field(description="Summary of document content")
content_length: int = Field(description="Length of document content in characters")
status: DocStatus = Field(description="Current processing status")
created_at: str = Field(description="Creation timestamp (ISO format string)")
updated_at: str = Field(description="Last update timestamp (ISO format string)")
chunks_count: Optional[int] = Field(
default=None, description="Number of chunks the document was split into"
)
error: Optional[str] = Field(
default=None, description="Error message if processing failed"
)
metadata: Optional[dict[str, Any]] = Field(
default=None, description="Additional metadata about the document"
)
file_path: str = Field(description="Path to the document file")
Attributes: class Config:
id: Document identifier json_schema_extra = {
content_summary: Summary of document content "example": {
content_length: Length of document content "id": "doc_123456",
status: Current processing status "content_summary": "Research paper on machine learning",
created_at: Creation timestamp (ISO format string) "content_length": 15240,
updated_at: Last update timestamp (ISO format string) "status": "PROCESSED",
chunks_count: Number of chunks (optional) "created_at": "2025-03-31T12:34:56",
error: Error message if any (optional) "updated_at": "2025-03-31T12:35:30",
metadata: Additional metadata (optional) "chunks_count": 12,
""" "error": None,
"metadata": {"author": "John Doe", "year": 2025},
id: str "file_path": "research_paper.pdf",
content_summary: str }
content_length: int }
status: DocStatus
created_at: str
updated_at: str
chunks_count: Optional[int] = None
error: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
file_path: str
class DocsStatusesResponse(BaseModel): class DocsStatusesResponse(BaseModel):
statuses: Dict[DocStatus, List[DocStatusResponse]] = {} """Response model for document statuses
Attributes:
statuses: Dictionary mapping document status to lists of document status responses
"""
statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
default_factory=dict,
description="Dictionary mapping document status to lists of document status responses",
)
class Config:
json_schema_extra = {
"example": {
"statuses": {
"PENDING": [
{
"id": "doc_123",
"content_summary": "Pending document",
"content_length": 5000,
"status": "PENDING",
"created_at": "2025-03-31T10:00:00",
"updated_at": "2025-03-31T10:00:00",
"file_path": "pending_doc.pdf",
}
],
"PROCESSED": [
{
"id": "doc_456",
"content_summary": "Processed document",
"content_length": 8000,
"status": "PROCESSED",
"created_at": "2025-03-31T09:00:00",
"updated_at": "2025-03-31T09:05:00",
"chunks_count": 8,
"file_path": "processed_doc.pdf",
}
],
}
}
}
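A hedged sketch of consuming the grouped status response above (base URL assumed, add an `X-API-Key` header if authentication is enabled):

```python
import requests

resp = requests.get("http://localhost:9621/documents", timeout=30)
for status, docs in resp.json()["statuses"].items():   # "PENDING", "PROCESSED", ...
    for doc in docs:
        print(status, doc["id"], doc["file_path"], doc.get("chunks_count"))
```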
class PipelineStatusResponse(BaseModel): class PipelineStatusResponse(BaseModel):
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
) )
return False return False
case ".pdf": case ".pdf":
if global_args["main_args"].document_loading_engine == "DOCLING": if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore if not pm.is_installed("docling"): # type: ignore
pm.install("docling") pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore from docling.document_converter import DocumentConverter # type: ignore
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
for page in reader.pages: for page in reader.pages:
content += page.extract_text() + "\n" content += page.extract_text() + "\n"
case ".docx": case ".docx":
if global_args["main_args"].document_loading_engine == "DOCLING": if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore if not pm.is_installed("docling"): # type: ignore
pm.install("docling") pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore from docling.document_converter import DocumentConverter # type: ignore
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
[paragraph.text for paragraph in doc.paragraphs] [paragraph.text for paragraph in doc.paragraphs]
) )
case ".pptx": case ".pptx":
if global_args["main_args"].document_loading_engine == "DOCLING": if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore if not pm.is_installed("docling"): # type: ignore
pm.install("docling") pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore from docling.document_converter import DocumentConverter # type: ignore
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
if hasattr(shape, "text"): if hasattr(shape, "text"):
content += shape.text + "\n" content += shape.text + "\n"
case ".xlsx": case ".xlsx":
if global_args["main_args"].document_loading_engine == "DOCLING": if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore if not pm.is_installed("docling"): # type: ignore
pm.install("docling") pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore from docling.document_converter import DocumentConverter # type: ignore
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
await rag.apipeline_process_enqueue_documents() await rag.apipeline_process_enqueue_documents()
# TODO: deprecate after /insert_file is removed
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
"""Save the uploaded file to a temporary location """Save the uploaded file to a temporary location
@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
if not new_files: if not new_files:
return return
# Get MAX_PARALLEL_INSERT from global_args["main_args"] # Get MAX_PARALLEL_INSERT from global_args
max_parallel = global_args["main_args"].max_parallel_insert max_parallel = global_args.max_parallel_insert
# Calculate batch size as 2 * MAX_PARALLEL_INSERT # Calculate batch size as 2 * MAX_PARALLEL_INSERT
batch_size = 2 * max_parallel batch_size = 2 * max_parallel
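The batch size is derived from the parallel-insert limit so that a fresh batch is always queued while the previous one is still processing. A minimal sketch of the resulting slicing, with illustrative values:

```python
# Illustrative values; in the server both numbers come from global_args.
max_parallel = 2                      # MAX_PARALLEL_INSERT
batch_size = 2 * max_parallel
new_files = [f"doc_{i}.pdf" for i in range(7)]
batches = [new_files[i : i + batch_size] for i in range(0, len(new_files), batch_size)]
print(batches)  # [['doc_0.pdf', ..., 'doc_3.pdf'], ['doc_4.pdf', 'doc_5.pdf', 'doc_6.pdf']]
```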
@@ -509,7 +704,9 @@ def create_document_routes(
# Create combined auth dependency for document routes # Create combined auth dependency for document routes
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.post("/scan", dependencies=[Depends(combined_auth)]) @router.post(
"/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
)
async def scan_for_new_documents(background_tasks: BackgroundTasks): async def scan_for_new_documents(background_tasks: BackgroundTasks):
""" """
Trigger the scanning process for new documents. Trigger the scanning process for new documents.
@@ -519,13 +716,18 @@ def create_document_routes(
that fact. that fact.
Returns: Returns:
dict: A dictionary containing the scanning status ScanResponse: A response object containing the scanning status
""" """
# Start the scanning process in the background # Start the scanning process in the background
background_tasks.add_task(run_scanning_process, rag, doc_manager) background_tasks.add_task(run_scanning_process, rag, doc_manager)
return {"status": "scanning_started"} return ScanResponse(
status="scanning_started",
message="Scanning process has been initiated in the background",
)
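A hedged sketch of triggering the background scan and reading the new typed response (base URL assumed):

```python
import requests

resp = requests.post("http://localhost:9621/documents/scan", timeout=10)
print(resp.json())  # {"status": "scanning_started", "message": "..."}
```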
@router.post("/upload", dependencies=[Depends(combined_auth)]) @router.post(
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
)
async def upload_to_input_dir( async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...) background_tasks: BackgroundTasks, file: UploadFile = File(...)
): ):
@@ -645,6 +847,7 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# TODO: deprecated, use /upload instead
@router.post( @router.post(
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)] "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
) )
@@ -688,6 +891,7 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# TODO: deprecated, use /upload instead
@router.post( @router.post(
"/file_batch", "/file_batch",
response_model=InsertResponse, response_model=InsertResponse,
@@ -752,32 +956,186 @@ def create_document_routes(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.delete( @router.delete(
"", response_model=InsertResponse, dependencies=[Depends(combined_auth)] "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
) )
async def clear_documents(): async def clear_documents():
""" """
Clear all documents from the RAG system. Clear all documents from the RAG system.
This endpoint deletes all text chunks, entities vector database, and relationships This endpoint deletes all documents, entities, relationships, and files from the system.
vector database, effectively clearing all documents from the RAG system. It uses the storage drop methods to properly clean up all data and removes all files
from the input directory.
Returns: Returns:
InsertResponse: A response object containing the status and message. ClearDocumentsResponse: A response object containing the status and message.
- status="success": All documents and files were successfully cleared.
- status="partial_success": Document clear job exit with some errors.
- status="busy": Operation could not be completed because the pipeline is busy.
- status="fail": All storage drop operations failed, with message
- message: Detailed information about the operation results, including counts
of deleted files and any errors encountered.
Raises: Raises:
HTTPException: If an error occurs during the clearing process (500). HTTPException: Raised when a serious error occurs during the clearing process,
with status code 500 and error details in the detail field.
""" """
try: from lightrag.kg.shared_storage import (
rag.text_chunks = [] get_namespace_data,
rag.entities_vdb = None get_pipeline_status_lock,
rag.relationships_vdb = None )
return InsertResponse(
status="success", message="All documents cleared successfully" # Get pipeline status and lock
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status_lock = get_pipeline_status_lock()
# Check and set status with lock
async with pipeline_status_lock:
if pipeline_status.get("busy", False):
return ClearDocumentsResponse(
status="busy",
message="Cannot clear documents while pipeline is busy",
)
# Set busy to true
pipeline_status.update(
{
"busy": True,
"job_name": "Clearing Documents",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "Starting document clearing process",
}
) )
# Clear history_messages in place so the shared list object stays intact
del pipeline_status["history_messages"][:]
pipeline_status["history_messages"].append(
"Starting document clearing process"
)
try:
# Use drop method to clear all data
drop_tasks = []
storages = [
rag.text_chunks,
rag.full_docs,
rag.entities_vdb,
rag.relationships_vdb,
rag.chunks_vdb,
rag.chunk_entity_relation_graph,
rag.doc_status,
]
# Log storage drop start
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(
"Starting to drop storage components"
)
for storage in storages:
if storage is not None:
drop_tasks.append(storage.drop())
# Wait for all drop tasks to complete
drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)
# Check for errors and log results
errors = []
storage_success_count = 0
storage_error_count = 0
for i, result in enumerate(drop_results):
storage_name = storages[i].__class__.__name__
if isinstance(result, Exception):
error_msg = f"Error dropping {storage_name}: {str(result)}"
errors.append(error_msg)
logger.error(error_msg)
storage_error_count += 1
else:
logger.info(f"Successfully dropped {storage_name}")
storage_success_count += 1
# Log storage drop results
if "history_messages" in pipeline_status:
if storage_error_count > 0:
pipeline_status["history_messages"].append(
f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
)
else:
pipeline_status["history_messages"].append(
f"Successfully dropped all {storage_success_count} storage components"
)
# If all storage operations failed, return error status and don't proceed with file deletion
if storage_success_count == 0 and storage_error_count > 0:
error_message = "All storage drop operations failed. Aborting document clearing process."
logger.error(error_message)
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(error_message)
return ClearDocumentsResponse(status="fail", message=error_message)
# Log file deletion start
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(
"Starting to delete files in input directory"
)
# Delete all files in input_dir
deleted_files_count = 0
file_errors_count = 0
for file_path in doc_manager.input_dir.glob("**/*"):
if file_path.is_file():
try:
file_path.unlink()
deleted_files_count += 1
except Exception as e:
logger.error(f"Error deleting file {file_path}: {str(e)}")
file_errors_count += 1
# Log file deletion results
if "history_messages" in pipeline_status:
if file_errors_count > 0:
pipeline_status["history_messages"].append(
f"Deleted {deleted_files_count} files with {file_errors_count} errors"
)
errors.append(f"Failed to delete {file_errors_count} files")
else:
pipeline_status["history_messages"].append(
f"Successfully deleted {deleted_files_count} files"
)
# Prepare final result message
final_message = ""
if errors:
final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
status = "partial_success"
else:
final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
status = "success"
# Log final result
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(final_message)
# Return response based on results
return ClearDocumentsResponse(status=status, message=final_message)
except Exception as e: except Exception as e:
logger.error(f"Error DELETE /documents: {str(e)}") error_msg = f"Error clearing documents: {str(e)}"
logger.error(error_msg)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(error_msg)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
# Reset busy status after completion
async with pipeline_status_lock:
pipeline_status["busy"] = False
completion_msg = "Document clearing process completed"
pipeline_status["latest_message"] = completion_msg
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].append(completion_msg)
@router.get( @router.get(
"/pipeline_status", "/pipeline_status",
@@ -850,7 +1208,9 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("", dependencies=[Depends(combined_auth)]) @router.get(
"", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
)
async def documents() -> DocsStatusesResponse: async def documents() -> DocsStatusesResponse:
""" """
Get the status of all documents in the system. Get the status of all documents in the system.
@@ -908,4 +1268,57 @@ def create_document_routes(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/clear_cache",
response_model=ClearCacheResponse,
dependencies=[Depends(combined_auth)],
)
async def clear_cache(request: ClearCacheRequest):
"""
Clear cache data from the LLM response cache storage.
This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
- "default" represents extraction cache.
- Other modes correspond to different query modes.
Args:
request (ClearCacheRequest): The request body containing optional modes to clear.
Returns:
ClearCacheResponse: A response object containing the status and message.
Raises:
HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
"""
try:
# Validate modes if provided
valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
if request.modes and not all(mode in valid_modes for mode in request.modes):
invalid_modes = [
mode for mode in request.modes if mode not in valid_modes
]
raise HTTPException(
status_code=400,
detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
)
# Call the aclear_cache method
await rag.aclear_cache(request.modes)
# Prepare success message
if request.modes:
message = f"Successfully cleared cache for modes: {request.modes}"
else:
message = "Successfully cleared all cache"
return ClearCacheResponse(status="success", message=message)
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Error clearing cache: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
return router return router
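A hedged sketch of exercising the new cache-clearing route from a client (base URL assumed):

```python
import requests

resp = requests.post(
    "http://localhost:9621/documents/clear_cache",
    json={"modes": ["default", "naive"]},   # omit "modes" to clear all cache
    timeout=30,
)
print(resp.json())  # {"status": "success", "message": "Successfully cleared cache for modes: ..."}
```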
View File
@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
""" """
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, Query
from ..utils_api import get_combined_auth_dependency from ..utils_api import get_combined_auth_dependency
@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graphs", dependencies=[Depends(combined_auth)]) @router.get("/graphs", dependencies=[Depends(combined_auth)])
async def get_knowledge_graph( async def get_knowledge_graph(
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False label: str = Query(..., description="Label to get knowledge graph for"),
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
): ):
""" """
Retrieve a connected subgraph of nodes where the label includes the specified label. Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows: When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes 1. Hops (path) to the starting node take precedence
2. Label matching nodes take precedence 2. Followed by the degree of the nodes
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args: Args:
label (str): Label to get knowledge graph for label (str): Label of the starting node
max_depth (int, optional): Maximum depth of graph. Defaults to 3. max_depth (int, optional): Maximum depth of the subgraph. Defaults to 3
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False. max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
Returns: Returns:
Dict[str, List[str]]: Knowledge graph for label Dict[str, List[str]]: Knowledge graph for label
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
return await rag.get_knowledge_graph( return await rag.get_knowledge_graph(
node_label=label, node_label=label,
max_depth=max_depth, max_depth=max_depth,
inclusive=inclusive, max_nodes=max_nodes,
min_degree=min_degree,
) )
return router return router
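A hedged sketch of calling the reworked `/graphs` route with its new query parameters (base URL assumed; the `nodes`/`edges` keys reflect the KnowledgeGraph response model and are read defensively):

```python
import requests

resp = requests.get(
    "http://localhost:9621/graphs",
    params={"label": "*", "max_depth": 3, "max_nodes": 100},
    timeout=60,
)
graph = resp.json()
print(len(graph.get("nodes", [])), len(graph.get("edges", [])))
```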
View File
@@ -7,14 +7,9 @@ import os
import sys import sys
import signal import signal
import pipmaster as pm import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file from lightrag.api.utils_api import display_splash_screen, check_env_file
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
from dotenv import load_dotenv from .config import global_args
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
def check_and_install_dependencies(): def check_and_install_dependencies():
@@ -59,20 +54,17 @@ def main():
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # kill command signal.signal(signal.SIGTERM, signal_handler) # kill command
# Parse all arguments using parse_args
args = parse_args(is_uvicorn_mode=False)
# Display startup information # Display startup information
display_splash_screen(args) display_splash_screen(global_args)
print("🚀 Starting LightRAG with Gunicorn") print("🚀 Starting LightRAG with Gunicorn")
print(f"🔄 Worker management: Gunicorn (workers={args.workers})") print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
print("🔍 Preloading app: Enabled") print("🔍 Preloading app: Enabled")
print("📝 Note: Using Gunicorn's preload feature for shared data initialization") print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
print("\n\n" + "=" * 80) print("\n\n" + "=" * 80)
print("MAIN PROCESS INITIALIZATION") print("MAIN PROCESS INITIALIZATION")
print(f"Process ID: {os.getpid()}") print(f"Process ID: {os.getpid()}")
print(f"Workers setting: {args.workers}") print(f"Workers setting: {global_args.workers}")
print("=" * 80 + "\n") print("=" * 80 + "\n")
# Import Gunicorn's StandaloneApplication # Import Gunicorn's StandaloneApplication
@@ -128,31 +120,43 @@ def main():
# Set configuration variables in gunicorn_config, prioritizing command line arguments # Set configuration variables in gunicorn_config, prioritizing command line arguments
gunicorn_config.workers = ( gunicorn_config.workers = (
args.workers if args.workers else int(os.getenv("WORKERS", 1)) global_args.workers
if global_args.workers
else int(os.getenv("WORKERS", 1))
) )
# Bind configuration prioritizes command line arguments # Bind configuration prioritizes command line arguments
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") host = (
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) global_args.host
if global_args.host != "0.0.0.0"
else os.getenv("HOST", "0.0.0.0")
)
port = (
global_args.port
if global_args.port != 9621
else int(os.getenv("PORT", 9621))
)
gunicorn_config.bind = f"{host}:{port}" gunicorn_config.bind = f"{host}:{port}"
# Log level configuration prioritizes command line arguments # Log level configuration prioritizes command line arguments
gunicorn_config.loglevel = ( gunicorn_config.loglevel = (
args.log_level.lower() global_args.log_level.lower()
if args.log_level if global_args.log_level
else os.getenv("LOG_LEVEL", "info") else os.getenv("LOG_LEVEL", "info")
) )
# Timeout configuration prioritizes command line arguments # Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = ( gunicorn_config.timeout = (
args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2)) global_args.timeout
if global_args.timeout
else int(os.getenv("TIMEOUT", 150 * 2))
) )
# Keepalive configuration # Keepalive configuration
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
# SSL configuration prioritizes command line arguments # SSL configuration prioritizes command line arguments
if args.ssl or os.getenv("SSL", "").lower() in ( if global_args.ssl or os.getenv("SSL", "").lower() in (
"true", "true",
"1", "1",
"yes", "yes",
@@ -160,12 +164,14 @@ def main():
"on", "on",
): ):
gunicorn_config.certfile = ( gunicorn_config.certfile = (
args.ssl_certfile global_args.ssl_certfile
if args.ssl_certfile if global_args.ssl_certfile
else os.getenv("SSL_CERTFILE") else os.getenv("SSL_CERTFILE")
) )
gunicorn_config.keyfile = ( gunicorn_config.keyfile = (
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") global_args.ssl_keyfile
if global_args.ssl_keyfile
else os.getenv("SSL_KEYFILE")
) )
# Set configuration options from the module # Set configuration options from the module
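The repeated pattern in this hunk is a precedence rule: an explicit CLI value wins, otherwise the environment variable is consulted, then the built-in default. A small sketch of that rule (my distillation, not library code):

```python
import os

def pick(cli_value, builtin_default, env_key, cast=str):
    # CLI value wins unless it equals the built-in default, in which case
    # the environment variable is consulted before falling back.
    if cli_value is not None and cli_value != builtin_default:
        return cli_value
    return cast(os.getenv(env_key, str(builtin_default)))

host = pick("0.0.0.0", "0.0.0.0", "HOST")  # env HOST if set, else "0.0.0.0"
port = pick(9621, 9621, "PORT", int)       # env PORT if set, else 9621
```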
@@ -190,13 +196,13 @@ def main():
# Import the application # Import the application
from lightrag.api.lightrag_server import get_application from lightrag.api.lightrag_server import get_application
return get_application(args) return get_application(global_args)
# Create the application # Create the application
app = GunicornApp("") app = GunicornApp("")
# Force workers to be an integer and greater than 1 for multi-process mode # Force workers to be an integer and greater than 1 for multi-process mode
workers_count = int(args.workers) workers_count = int(global_args.workers)
if workers_count > 1: if workers_count > 1:
# Set a flag to indicate we're in the main process # Set a flag to indicate we're in the main process
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
View File
@@ -7,15 +7,13 @@ import argparse
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
import sys import sys
from ascii_colors import ASCIIColors from ascii_colors import ASCIIColors
import logging
from lightrag.api import __api_version__ as api_version from lightrag.api import __api_version__ as api_version
from lightrag import __version__ as core_version from lightrag import __version__ as core_version
from fastapi import HTTPException, Security, Request, status from fastapi import HTTPException, Security, Request, status
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler from .auth import auth_handler
from ..prompt import PROMPTS from .config import ollama_server_infos, global_args
def check_env_file(): def check_env_file():
@@ -36,16 +34,8 @@ def check_env_file():
return True return True
# use the .env that is inside the current folder # Get whitelist paths from global_args, only once during initialization
# allows to use different .env file for each lightrag instance whitelist_paths = global_args.whitelist_paths.split(",")
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
global_args = {"main_args": None}
# Get whitelist paths from environment variable, only once during initialization
default_whitelist = "/health,/api/*"
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
# Pre-compile path matching patterns # Pre-compile path matching patterns
whitelist_patterns: List[Tuple[str, bool]] = [] whitelist_patterns: List[Tuple[str, bool]] = []
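A self-contained sketch of the precompiled whitelist matching assumed here; the pattern format ("/path" exact, "/path/*" prefix) mirrors the WHITELIST_PATHS syntax above:

```python
from typing import List, Tuple

def compile_whitelist(paths: List[str]) -> List[Tuple[str, bool]]:
    patterns: List[Tuple[str, bool]] = []
    for path in paths:
        path = path.strip()
        if path.endswith("/*"):
            patterns.append((path[:-2], True))    # wildcard: prefix match
        elif path:
            patterns.append((path, False))        # exact match
    return patterns

def is_whitelisted(request_path: str, patterns: List[Tuple[str, bool]]) -> bool:
    return any(
        request_path.startswith(prefix) if wildcard else request_path == prefix
        for prefix, wildcard in patterns
    )

patterns = compile_whitelist("/health,/api/*".split(","))
assert is_whitelisted("/health", patterns)
assert is_whitelisted("/api/version", patterns)
assert not is_whitelisted("/documents", patterns)
```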
@@ -63,19 +53,6 @@ for path in whitelist_paths:
auth_configured = bool(auth_handler.accounts) auth_configured = bool(auth_handler.accounts)
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
ollama_server_infos = OllamaServerInfos()
def get_combined_auth_dependency(api_key: Optional[str] = None): def get_combined_auth_dependency(api_key: Optional[str] = None):
""" """
Create a combined authentication dependency that implements authentication logic Create a combined authentication dependency that implements authentication logic
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
return combined_dependency return combined_dependency
class DefaultRAGStorageConfig:
KV_STORAGE = "JsonKVStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
GRAPH_STORAGE = "NetworkXStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
}
return default_hosts.get(
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
) # fallback to ollama if unknown
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
"""
Get value from environment variable with type conversion
Args:
env_key (str): Environment variable key
default (any): Default value if env variable is not set
value_type (type): Type to convert the value to
Returns:
any: Converted value from environment or default
"""
value = os.getenv(env_key)
if value is None:
return default
if value_type is bool:
return value.lower() in ("true", "1", "yes", "t", "on")
try:
return value_type(value)
except ValueError:
return default
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
"""
Parse command line arguments with environment variable fallback
Args:
is_uvicorn_mode: Whether running under uvicorn mode
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Server configuration
parser.add_argument(
"--host",
default=get_env_value("HOST", "0.0.0.0"),
help="Server host (default: from env or 0.0.0.0)",
)
parser.add_argument(
"--port",
type=int,
default=get_env_value("PORT", 9621, int),
help="Server port (default: from env or 9621)",
)
# Directory configuration
parser.add_argument(
"--working-dir",
default=get_env_value("WORKING_DIR", "./rag_storage"),
help="Working directory for RAG storage (default: from env or ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default=get_env_value("INPUT_DIR", "./inputs"),
help="Directory containing input documents (default: from env or ./inputs)",
)
def timeout_type(value):
if value is None:
return 150
if value == "None":
return None
return int(value)
parser.add_argument(
"--timeout",
default=get_env_value("TIMEOUT", None, timeout_type),
type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
)
# RAG configuration
parser.add_argument(
"--max-async",
type=int,
default=get_env_value("MAX_ASYNC", 4, int),
help="Maximum async operations (default: from env or 4)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=get_env_value("MAX_TOKENS", 32768, int),
help="Maximum token size (default: from env or 32768)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default=get_env_value("LOG_LEVEL", "INFO"),
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: from env or INFO)",
)
parser.add_argument(
"--verbose",
action="store_true",
default=get_env_value("VERBOSE", False, bool),
help="Enable verbose debug output (only valid for DEBUG log-level)",
)
parser.add_argument(
"--key",
type=str,
default=get_env_value("LIGHTRAG_API_KEY", None),
help="API key for authentication. This protects lightrag server against unauthorized access",
)
# Optional https parameters
parser.add_argument(
"--ssl",
action="store_true",
default=get_env_value("SSL", False, bool),
help="Enable HTTPS (default: from env or False)",
)
parser.add_argument(
"--ssl-certfile",
default=get_env_value("SSL_CERTFILE", None),
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=get_env_value("SSL_KEYFILE", None),
help="Path to SSL private key file (required if --ssl is enabled)",
)
parser.add_argument(
"--history-turns",
type=int,
default=get_env_value("HISTORY_TURNS", 3, int),
help="Number of conversation history turns to include (default: from env or 3)",
)
# Search parameters
parser.add_argument(
"--top-k",
type=int,
default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
help="Cosine similarity threshold (default: from env or 0.2)",
)
# Ollama model name
parser.add_argument(
"--simulated-model-name",
type=str,
default=get_env_value(
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
),
help="Name of the simulated Ollama model (default: from env or lightrag:latest)",
)
# Namespace
parser.add_argument(
"--namespace-prefix",
type=str,
default=get_env_value("NAMESPACE_PREFIX", ""),
help="Prefix of the namespace",
)
parser.add_argument(
"--auto-scan-at-startup",
action="store_true",
default=False,
help="Enable automatic scanning when the program starts",
)
# Server workers configuration
parser.add_argument(
"--workers",
type=int,
default=get_env_value("WORKERS", 1, int),
help="Number of worker processes (default: from env or 1)",
)
# LLM and embedding bindings
parser.add_argument(
"--llm-binding",
type=str,
default=get_env_value("LLM_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
help="LLM binding type (default: from env or ollama)",
)
parser.add_argument(
"--embedding-binding",
type=str,
default=get_env_value("EMBEDDING_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "azure_openai"],
help="Embedding binding type (default: from env or ollama)",
)
args = parser.parse_args()
# If in uvicorn mode and workers > 1, force it to 1 and log warning
if is_uvicorn_mode and args.workers > 1:
original_workers = args.workers
args.workers = 1
# Log warning directly here
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
)
# convert relative path to absolute path
args.working_dir = os.path.abspath(args.working_dir)
args.input_dir = os.path.abspath(args.input_dir)
# Inject storage configuration from environment variables
args.kv_storage = get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
)
args.doc_status_storage = get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
)
args.graph_storage = get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
)
args.vector_storage = get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
)
# Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
# Handle openai-ollama special case
if args.llm_binding == "openai-ollama":
args.llm_binding = "openai"
args.embedding_binding = "ollama"
args.llm_binding_host = get_env_value(
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
)
args.embedding_binding_host = get_env_value(
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
)
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
# Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
# Inject chunk configuration
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
# Inject LLM cache configuration
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
)
# Inject LLM temperature configuration
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
# Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
global_args["main_args"] = args
return args
def display_splash_screen(args: argparse.Namespace) -> None: def display_splash_screen(args: argparse.Namespace) -> None:
""" """
Display a colorful splash screen showing LightRAG server configuration Display a colorful splash screen showing LightRAG server configuration
@@ -489,7 +173,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
# Banner # Banner
ASCIIColors.cyan(f""" ASCIIColors.cyan(f"""
╔══════════════════════════════════════════════════════════════╗ ╔══════════════════════════════════════════════════════════════╗
🚀 LightRAG Server v{core_version}/{api_version} ║ 🚀 LightRAG Server v{core_version}/{api_version}
║ Fast, Lightweight RAG Server Implementation ║ ║ Fast, Lightweight RAG Server Implementation ║
╚══════════════════════════════════════════════════════════════╝ ╚══════════════════════════════════════════════════════════════╝
""") """)
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.white(" ├─ Workers: ", end="") ASCIIColors.white(" ├─ Workers: ", end="")
ASCIIColors.yellow(f"{args.workers}") ASCIIColors.yellow(f"{args.workers}")
ASCIIColors.white(" ├─ CORS Origins: ", end="") ASCIIColors.white(" ├─ CORS Origins: ", end="")
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") ASCIIColors.yellow(f"{args.cors_origins}")
ASCIIColors.white(" ├─ SSL Enabled: ", end="") ASCIIColors.white(" ├─ SSL Enabled: ", end="")
ASCIIColors.yellow(f"{args.ssl}") ASCIIColors.yellow(f"{args.ssl}")
if args.ssl: if args.ssl:
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.verbose}") ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" ├─ History Turns: ", end="") ASCIIColors.white(" ├─ History Turns: ", end="")
ASCIIColors.yellow(f"{args.history_turns}") ASCIIColors.yellow(f"{args.history_turns}")
ASCIIColors.white(" └─ API Key: ", end="") ASCIIColors.white(" ├─ API Key: ", end="")
ASCIIColors.yellow("Set" if args.key else "Not Set") ASCIIColors.yellow("Set" if args.key else "Not Set")
ASCIIColors.white(" └─ JWT Auth: ", end="")
ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
# Directory Configuration # Directory Configuration
ASCIIColors.magenta("\n📂 Directory Configuration:") ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.embedding_dim}") ASCIIColors.yellow(f"{args.embedding_dim}")
# RAG Configuration # RAG Configuration
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
ASCIIColors.magenta("\n⚙️ RAG Configuration:") ASCIIColors.magenta("\n⚙️ RAG Configuration:")
ASCIIColors.white(" ├─ Summary Language: ", end="") ASCIIColors.white(" ├─ Summary Language: ", end="")
ASCIIColors.yellow(f"{summary_language}") ASCIIColors.yellow(f"{args.summary_language}")
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="") ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
ASCIIColors.yellow(f"{args.max_parallel_insert}") ASCIIColors.yellow(f"{args.max_parallel_insert}")
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
protocol = "https" if args.ssl else "http" protocol = "https" if args.ssl else "http"
if args.host == "0.0.0.0": if args.host == "0.0.0.0":
ASCIIColors.magenta("\n🌐 Server Access Information:") ASCIIColors.magenta("\n🌐 Server Access Information:")
ASCIIColors.white(" ├─ Local Access: ", end="") ASCIIColors.white(" ├─ WebUI (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}") ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
ASCIIColors.white(" ├─ Remote Access: ", end="") ASCIIColors.white(" ├─ Remote Access: ", end="")
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}") ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
ASCIIColors.white(" ├─ API Documentation (local): ", end="") ASCIIColors.white(" ├─ API Documentation (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs") ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="") ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc") ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
ASCIIColors.white(" └─ WebUI (local): ", end="")
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
ASCIIColors.yellow("\n📝 Note:") ASCIIColors.magenta("\n📝 Note:")
ASCIIColors.white(""" Since the server is running on 0.0.0.0: ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
- Use 'localhost' or '127.0.0.1' for local access - Use 'localhost' or '127.0.0.1' for local access
- Use your machine's IP address for remote access - Use your machine's IP address for remote access
- To find your IP address: - To find your IP address:
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
else: else:
base_url = f"{protocol}://{args.host}:{args.port}" base_url = f"{protocol}://{args.host}:{args.port}"
ASCIIColors.magenta("\n🌐 Server Access Information:") ASCIIColors.magenta("\n🌐 Server Access Information:")
ASCIIColors.white(" ├─ Base URL: ", end="") ASCIIColors.white(" ├─ WebUI (local): ", end="")
ASCIIColors.yellow(f"{base_url}") ASCIIColors.yellow(f"{base_url}")
ASCIIColors.white(" ├─ API Documentation: ", end="") ASCIIColors.white(" ├─ API Documentation: ", end="")
ASCIIColors.yellow(f"{base_url}/docs") ASCIIColors.yellow(f"{base_url}/docs")
ASCIIColors.white(" └─ Alternative Documentation: ", end="") ASCIIColors.white(" └─ Alternative Documentation: ", end="")
ASCIIColors.yellow(f"{base_url}/redoc") ASCIIColors.yellow(f"{base_url}/redoc")
# Usage Examples
ASCIIColors.magenta("\n📚 Quick Start Guide:")
ASCIIColors.cyan("""
1. Access the Swagger UI:
Open your browser and navigate to the API documentation URL above
2. API Authentication:""")
if args.key:
ASCIIColors.cyan(""" Add the following header to your requests:
X-API-Key: <your-api-key>
""")
else:
ASCIIColors.cyan(" No authentication required\n")
ASCIIColors.cyan(""" 3. Basic Operations:
- POST /upload_document: Upload new documents to RAG
- POST /query: Query your document collection
4. Monitor the server:
- Check server logs for detailed operation information
- Use healthcheck endpoint: GET /health
""")
# Security Notice # Security Notice
if args.key: if args.key:
ASCIIColors.yellow("\n⚠️ Security Notice:") ASCIIColors.yellow("\n⚠️ Security Notice:")
ASCIIColors.white(""" API Key authentication is enabled. ASCIIColors.white(""" API Key authentication is enabled.
Make sure to include the X-API-Key header in all your requests. Make sure to include the X-API-Key header in all your requests.
""") """)
if args.auth_accounts:
ASCIIColors.yellow("\n⚠️ Security Notice:")
ASCIIColors.white(""" JWT authentication is enabled.
Make sure to login before making the request, and include the 'Authorization' in the header.
""")
# Ensure splash output flush to system log # Ensure splash output flush to system log
sys.stdout.flush() sys.stdout.flush()
File diffs suppressed because one or more lines are too long (generated webui assets, including lightrag/api/webui/assets/index-Cma7xY0-.js, 1345 lines)

View File
@@ -8,8 +8,8 @@
<link rel="icon" type="image/svg+xml" href="logo.png" /> <link rel="icon" type="image/svg+xml" href="logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title> <title>Lightrag</title>
<script type="module" crossorigin src="/webui/assets/index-raheqJeu.js"></script> <script type="module" crossorigin src="/webui/assets/index-Cma7xY0-.js"></script>
<link rel="stylesheet" crossorigin href="/webui/assets/index-CD5HxTy1.css"> <link rel="stylesheet" crossorigin href="/webui/assets/index-QU59h9JG.css">
</head> </head>
<body> <body>
<div id="root"></div> <div id="root"></div>
View File
@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
"""Commit the storage operations after indexing""" """Commit the storage operations after indexing"""
@abstractmethod
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
This abstract method defines the contract for dropping all data from a storage implementation.
Each storage type must implement this method to:
1. Clear all data from memory and/or external storage
2. Remove any associated storage files if applicable
3. Reset the storage to its initial state
4. Handle cleanup of any resources
5. Notify other processes if necessary
6. Persist the data to disk immediately after this action
Returns:
dict[str, str]: Operation status and message with the following format:
{
"status": str, # "success" or "error"
"message": str # "data dropped" on success, error details on failure
}
Implementation specific:
- On success: return {"status": "success", "message": "data dropped"}
- On failure: return {"status": "error", "message": "<error details>"}
- If not supported: return {"status": "error", "message": "unsupported"}
"""
@dataclass @dataclass
class BaseVectorStorage(StorageNameSpace, ABC): class BaseVectorStorage(StorageNameSpace, ABC):
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Insert or update vectors in the storage.""" """Insert or update vectors in the storage.
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should be updating the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod @abstractmethod
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name.""" """Delete a single entity by its name.
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should be updating the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod @abstractmethod
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity.""" """Delete relations for a given entity.
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should be updating the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod @abstractmethod
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
""" """
pass pass
@abstractmethod
async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should be updating the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
Args:
ids: List of vector IDs to be deleted
"""
@dataclass @dataclass
class BaseKVStorage(StorageNameSpace, ABC): class BaseKVStorage(StorageNameSpace, ABC):
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""Upsert data""" """Upsert data
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
"""
@abstractmethod
async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of document IDs to be deleted from storage
Returns:
None
"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
True: if the cache was dropped successfully
False: if the drop failed, or the cache mode is not supported
"""
@dataclass @dataclass
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get an edge by its source and target node ids.""" """Get node by its label identifier, return only node properties"""
@abstractmethod @abstractmethod
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
"""Get all edges connected to a node.""" """Get edge properties between two nodes"""
@abstractmethod @abstractmethod
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
"""Delete a node from the graph.""" """Delete a node from the graph.
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should updating the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
@abstractmethod @abstractmethod
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
@abstractmethod @abstractmethod
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 3 self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
) -> KnowledgeGraph: ) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node.""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args:
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maximum nodes to return, Defaults to 1000 (BFS if possible)
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
class DocStatus(str, Enum): class DocStatus(str, Enum):
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
) -> dict[str, DocProcessingStatus]: ) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status""" """Get all documents with a specific status"""
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Drop cache is not supported for Doc Status storage"""
return False
class StoragesStatus(str, Enum): class StoragesStatus(str, Enum):
"""Storages status""" """Storages status"""

View File

@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": { "KV_STORAGE": {
"implementations": [ "implementations": [
"JsonKVStorage", "JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage", "RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage", "PGKVStorage",
"OracleKVStorage", "MongoKVStorage",
# "TiDBKVStorage",
], ],
"required_methods": ["get_by_id", "upsert"], "required_methods": ["get_by_id", "upsert"],
}, },
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
"implementations": [ "implementations": [
"NetworkXStorage", "NetworkXStorage",
"Neo4JStorage", "Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage", "PGGraphStorage",
"OracleGraphStorage", # "AGEStorage",
# "MongoGraphStorage",
# "TiDBGraphStorage",
# "GremlinStorage",
], ],
"required_methods": ["upsert_node", "upsert_edge"], "required_methods": ["upsert_node", "upsert_edge"],
}, },
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
"NanoVectorDBStorage", "NanoVectorDBStorage",
"MilvusVectorDBStorage", "MilvusVectorDBStorage",
"ChromaVectorDBStorage", "ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage", "PGVectorStorage",
"FaissVectorDBStorage", "FaissVectorDBStorage",
"QdrantVectorDBStorage", "QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage", "MongoVectorDBStorage",
# "TiDBVectorDBStorage",
], ],
"required_methods": ["query", "upsert"], "required_methods": ["query", "upsert"],
}, },
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
"implementations": [ "implementations": [
"JsonDocStatusStorage", "JsonDocStatusStorage",
"PGDocStatusStorage", "PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage", "MongoDocStatusStorage",
], ],
"required_methods": ["get_docs_by_status"], "required_methods": ["get_docs_by_status"],
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"JsonKVStorage": [], "JsonKVStorage": [],
"MongoKVStorage": [], "MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"], "RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations # Graph Storage Implementations
"NetworkXStorage": [], "NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [], "MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [ "AGEStorage": [
"AGE_POSTGRES_DB", "AGE_POSTGRES_DB",
"AGE_POSTGRES_USER", "AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD", "AGE_POSTGRES_PASSWORD",
], ],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [ "PGGraphStorage": [
"POSTGRES_USER", "POSTGRES_USER",
"POSTGRES_PASSWORD", "POSTGRES_PASSWORD",
"POSTGRES_DATABASE", "POSTGRES_DATABASE",
], ],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations # Vector Storage Implementations
"NanoVectorDBStorage": [], "NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [], "MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [], "ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [], "FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [], "MongoVectorDBStorage": [],
# Document Status Storage Implementations # Document Status Storage Implementations
"JsonDocStatusStorage": [], "JsonDocStatusStorage": [],
@@ -112,9 +90,6 @@ STORAGES = {
"NanoVectorDBStorage": ".kg.nano_vector_db_impl", "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.json_doc_status_impl", "JsonDocStatusStorage": ".kg.json_doc_status_impl",
"Neo4JStorage": ".kg.neo4j_impl", "Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorage": ".kg.milvus_impl", "MilvusVectorDBStorage": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl", "MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl", "MongoDocStatusStorage": ".kg.mongo_impl",
@@ -122,14 +97,14 @@ STORAGES = {
"MongoVectorDBStorage": ".kg.mongo_impl", "MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl", "RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl", "ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl", # "TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl", # "TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl", # "TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl", "PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl", "PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl", "AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl", "PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl", # "GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl", "PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl", "FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl",

View File

@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
if not pm.is_installed("asyncpg"): if not pm.is_installed("asyncpg"):
pm.install("asyncpg") pm.install("asyncpg")
import psycopg import psycopg # type: ignore
from psycopg.rows import namedtuple_row from psycopg.rows import namedtuple_row # type: ignore
from psycopg_pool import AsyncConnectionPool, PoolTimeout from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
class AGEQueryException(Exception): class AGEQueryException(Exception):
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# AGES handles persistence automatically # AGES handles persistence automatically
pass pass
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all nodes and relationships in the graph.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
query = """
MATCH (n)
DETACH DELETE n
"""
await self._query(query)
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.graph_name}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, final from typing import Any, final
import numpy as np import numpy as np
@@ -10,8 +11,8 @@ import pipmaster as pm
if not pm.is_installed("chromadb"): if not pm.is_installed("chromadb"):
pm.install("chromadb") pm.install("chromadb")
from chromadb import HttpClient, PersistentClient from chromadb import HttpClient, PersistentClient # type: ignore
from chromadb.config import Settings from chromadb.config import Settings # type: ignore
@final @final
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return [] return []
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will delete all documents from the ChromaDB collection.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
# Get all IDs in the collection
result = self._collection.get(include=[])
if result and result["ids"] and len(result["ids"]) > 0:
# Delete all documents
self._collection.delete(ids=result["ids"])
logger.info(
f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -11,16 +11,20 @@ import pipmaster as pm
from lightrag.utils import logger, compute_mdhash_id from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage from lightrag.base import BaseVectorStorage
if not pm.is_installed("faiss"):
pm.install("faiss")
import faiss # type: ignore
from .shared_storage import ( from .shared_storage import (
get_storage_lock, get_storage_lock,
get_update_flag, get_update_flag,
set_all_update_flags, set_all_update_flags,
) )
import faiss # type: ignore
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
if not pm.is_installed(FAISS_PACKAGE):
pm.install(FAISS_PACKAGE)
@final @final
@dataclass @dataclass
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
async def delete(self, ids: list[str]): async def delete(self, ids: list[str]):
""" """
Delete vectors for the provided custom IDs. Delete vectors for the provided custom IDs.
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
""" """
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
to_remove = [] to_remove = []
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
) )
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
await self.delete([entity_id]) await self.delete([entity_id])
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
""" """
Delete relations for a given entity by scanning metadata. Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
""" """
logger.debug(f"Searching relations for entity {entity_name}") logger.debug(f"Searching relations for entity {entity_name}")
relations = [] relations = []
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
results.append({**metadata, "id": metadata.get("__id__")}) results.append({**metadata, "id": metadata.get("__id__")})
return results return results
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will:
1. Remove the vector database storage file if it exists
2. Reinitialize the vector database client
3. Update flags to notify other processes
4. Changes are persisted to disk immediately
This method will remove all vectors from the Faiss index and delete the storage files.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
# Reset the index
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
# Remove storage files if they exist
if os.path.exists(self._faiss_index_file):
os.remove(self._faiss_index_file)
if os.path.exists(self._meta_file):
os.remove(self._meta_file)
self._id_to_meta = {}
self._load_faiss_index()
# Notify other processes
await set_all_update_flags(self.namespace)
self.storage_updated.value = False
logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
if not pm.is_installed("gremlinpython"): if not pm.is_installed("gremlinpython"):
pm.install("gremlinpython") pm.install("gremlinpython")
from gremlin_python.driver import client, serializer from gremlin_python.driver import client, serializer # type: ignore
from gremlin_python.driver.aiohttp.transport import AiohttpTransport from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
from gremlin_python.driver.protocol import GremlinServerError from gremlin_python.driver.protocol import GremlinServerError # type: ignore
@final @final
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
except Exception as e: except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}") logger.error(f"Error during edge deletion: {str(e)}")
raise raise
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all nodes and relationships in the graph.
This function deletes all nodes with the specified graph name property,
which automatically removes all associated edges.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
query = f"""g
.V().has('graph', {self.graph_name})
.drop()
"""
await self._query(query)
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.graph_name}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
await clear_all_update_flags(self.namespace) await clear_all_update_flags(self.namespace)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
"""
if not data: if not data:
return return
logger.info(f"Inserting {len(data)} records to {self.namespace}") logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
async with self._storage_lock: async with self._storage_lock:
return self._data.get(id) return self._data.get(id)
async def delete(self, doc_ids: list[str]): async def delete(self, doc_ids: list[str]) -> None:
async with self._storage_lock: """Delete specific records from storage by their IDs
for doc_id in doc_ids:
self._data.pop(doc_id, None)
await set_all_update_flags(self.namespace)
await self.index_done_callback()
async def drop(self) -> None: Important notes for in-memory storage:
"""Drop the storage""" 1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
doc_ids (list[str]): List of document IDs to be deleted from storage
Returns:
None
"""
async with self._storage_lock: async with self._storage_lock:
self._data.clear() any_deleted = False
await set_all_update_flags(self.namespace) for doc_id in doc_ids:
await self.index_done_callback() result = self._data.pop(doc_id, None)
if result is not None:
any_deleted = True
if any_deleted:
await set_all_update_flags(self.namespace)
async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources
This method will:
1. Clear all document status data from memory
2. Update flags to notify other processes
3. Trigger index_done_callback to save the empty state
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
self._data.clear()
await set_all_update_flags(self.namespace)
await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
return set(keys) - set(self._data.keys()) return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
"""
if not data: if not data:
return return
logger.info(f"Inserting {len(data)} records to {self.namespace}") logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.namespace) await set_all_update_flags(self.namespace)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs
Important notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. Update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of document IDs to be deleted from storage
Returns:
None
"""
async with self._storage_lock: async with self._storage_lock:
any_deleted = False
for doc_id in ids: for doc_id in ids:
self._data.pop(doc_id, None) result = self._data.pop(doc_id, None)
await set_all_update_flags(self.namespace) if result is not None:
await self.index_done_callback() any_deleted = True
if any_deleted:
await set_all_update_flags(self.namespace)
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by by cache mode
Importance notes for in-memory storage:
1. Changes will be persisted to disk during the next index_done_callback
2. update flags to notify other processes that data persistence is needed
Args:
ids (list[str]): List of cache mode to be drop from storage
Returns:
True: if the cache drop successfully
False: if the cache drop failed
"""
if not modes:
return False
try:
await self.delete(modes)
return True
except Exception:
return False
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
This action persists the data to disk immediately.
This method will:
1. Clear all data from memory
2. Update flags to notify other processes
3. Trigger index_done_callback to save the empty state
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
self._data.clear()
await set_all_update_flags(self.namespace)
await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.namespace}")
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
pm.install("pymilvus") pm.install("pymilvus")
import configparser import configparser
from pymilvus import MilvusClient from pymilvus import MilvusClient # type: ignore
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return [] return []
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will delete all data from the Milvus collection.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
# Drop the collection and recreate it
if self._client.has_collection(self.namespace):
self._client.drop_collection(self.namespace)
# Recreate the collection
MilvusVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
dimension=self.embedding_func.embedding_dim,
)
logger.info(
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
if not pm.is_installed("motor"): if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
from motor.motor_asyncio import ( from motor.motor_asyncio import ( # type: ignore
AsyncIOMotorClient, AsyncIOMotorClient,
AsyncIOMotorDatabase, AsyncIOMotorDatabase,
AsyncIOMotorCollection, AsyncIOMotorCollection,
) )
from pymongo.operations import SearchIndexModel from pymongo.operations import SearchIndexModel # type: ignore
from pymongo.errors import PyMongoError from pymongo.errors import PyMongoError # type: ignore
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
# Mongo handles persistence automatically # Mongo handles persistence automatically
pass pass
async def delete(self, ids: list[str]) -> None:
"""Delete documents with specified IDs
Args:
ids: List of document IDs to be deleted
"""
if not ids:
return
try:
result = await self._data.delete_many({"_id": {"$in": ids}})
logger.info(
f"Deleted {result.deleted_count} documents from {self.namespace}"
)
except PyMongoError as e:
logger.error(f"Error deleting documents from {self.namespace}: {e}")
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
bool: True if successful, False otherwise
"""
if not modes:
return False
try:
# Build regex pattern to match documents with the specified modes
pattern = f"^({'|'.join(modes)})_"
result = await self._data.delete_many({"_id": {"$regex": pattern}})
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
return True
except Exception as e:
logger.error(f"Error deleting cache by modes {modes}: {e}")
return False
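The regex above assumes cache documents are keyed as `<mode>_<hash>`. A small self-contained illustration of the pattern it builds (the key values here are hypothetical):

```python
import re

modes = ["local", "global"]
pattern = f"^({'|'.join(modes)})_"  # -> "^(local|global)_"

assert re.match(pattern, "local_6f1dabc")       # matched: would be dropped
assert re.match(pattern, "global_93e2def")      # matched: would be dropped
assert not re.match(pattern, "naive_4b77012")   # mode not listed: kept
```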
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
result = await self._data.delete_many({})
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
@final @final
@dataclass @dataclass
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
# Mongo handles persistence automatically # Mongo handles persistence automatically
pass pass
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
result = await self._data.delete_many({})
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
@final @final
@dataclass @dataclass
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):
logger.debug(f"Successfully deleted edges: {edges}") logger.debug(f"Successfully deleted edges: {edges}")
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
result = await self.collection.delete_many({})
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from graph {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping graph {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
@final @final
@dataclass @dataclass
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return [] return []
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection and recreating vector index.
Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message'
"""
try:
# Delete all documents
result = await self._data.delete_many({})
deleted_count = result.deleted_count
# Recreate vector index
await self.create_vector_index_if_not_exists()
logger.info(
f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped and vector index recreated",
}
except PyMongoError as e:
logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names() collection_names = await db.list_collection_names()

View File

@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
return self._client return self._client
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def delete(self, ids: list[str]): async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs """Delete vectors with specified IDs
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
Args: Args:
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
try: try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug( logger.debug(
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting entity {entity_name}: {e}") logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
KG-storage-log should be used to avoid data corruption
"""
try: try:
client = await self._get_client() client = await self._get_client()
storage = getattr(client, "_NanoVectorDB__storage") storage = getattr(client, "_NanoVectorDB__storage")
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
client = await self._get_client() client = await self._get_client()
return client.get(ids) return client.get(ids)
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will:
1. Remove the vector database storage file if it exists
2. Reinitialize the vector database client
3. Update flags to notify other processes
4. Changes are persisted to disk immediately
This method is intended for scenarios where all data needs to be removed.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
# delete _client_file_name
if os.path.exists(self._client_file_name):
os.remove(self._client_file_name)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
logger.info(
f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -1,9 +1,8 @@
import asyncio
import inspect import inspect
import os import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, final, Optional from typing import Any, final
import numpy as np import numpy as np
import configparser import configparser
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
exceptions as neo4jExceptions, exceptions as neo4jExceptions,
AsyncDriver, AsyncDriver,
AsyncManagedTransaction, AsyncManagedTransaction,
GraphDatabase,
) )
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self._driver = None self._driver = None
self._driver_lock = asyncio.Lock()
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
USERNAME = os.environ.get( USERNAME = os.environ.get(
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
), ),
) )
DATABASE = os.environ.get( DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
) )
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
) )
# Try to connect to the database # Try to connect to the database and create it if it doesn't exist
with GraphDatabase.driver( for database in (DATABASE, None):
URI, self._DATABASE = database
auth=(USERNAME, PASSWORD), connected = False
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
) as _sync_driver:
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try: try:
with _sync_driver.session(database=database) as session: async with self._driver.session(database=database) as session:
try: try:
session.run("MATCH (n) RETURN n LIMIT 0") result = await session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {database} at {URI}") await result.consume() # Ensure result is consumed
connected = True logger.info(f"Connected to {database} at {URI}")
except neo4jExceptions.ServiceUnavailable as e: connected = True
logger.error( except neo4jExceptions.ServiceUnavailable as e:
f"{database} at {URI} is not available".capitalize() logger.error(
) f"{database} at {URI} is not available".capitalize()
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
) )
try: raise e
with _sync_driver.session() as session: except neo4jExceptions.AuthError as e:
session.run( logger.error(f"Authentication failed for {database} at {URI}")
f"CREATE DATABASE `{database}` IF NOT EXISTS" raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
)
try:
async with self._driver.session() as session:
result = await session.run(
f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize())
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
) )
logger.info(f"{database} at {URI} created".capitalize()) if database is None:
connected = True logger.error(f"Failed to create {database} at {URI}")
except ( raise e
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
if database is None:
logger.error(f"Failed to create {database} at {URI}")
raise e
if connected: if connected:
break # Create index for base nodes on entity_id if it doesn't exist
try:
async with self._driver.session(database=database) as session:
# Check if index exists first
check_query = """
CALL db.indexes() YIELD name, labelsOrTypes, properties
WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
RETURN count(*) > 0 AS exists
"""
try:
check_result = await session.run(check_query)
record = await check_result.single()
await check_result.consume()
def __post_init__(self): index_exists = record and record.get("exists", False)
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def close(self): if not index_exists:
# Create index only if it doesn't exist
result = await session.run(
"CREATE INDEX FOR (n:base) ON (n.entity_id)"
)
await result.consume()
logger.info(
f"Created index for base nodes on entity_id in {database}"
)
except Exception:
# Fallback if db.indexes() is not supported in this Neo4j version
result = await session.run(
"CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
)
await result.consume()
except Exception as e:
logger.warning(f"Failed to create index: {str(e)}")
break
async def finalize(self):
"""Close the Neo4j driver and release all resources""" """Close the Neo4j driver and release all resources"""
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits""" """Ensure driver is closed when context manager exits"""
await self.close() await self.finalize()
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# Neo4j handles persistence automatically # Neo4j handles persistence automatically
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
raise raise
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier. """Get node by its label identifier, return only node properties
Args: Args:
node_id: The node label to look up node_id: The node label to look up
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.debug( logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}" f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
) )
# Return default edge properties when no edge found # Return None when no edge found
return { return None
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
finally: finally:
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
""" """
properties = node_data properties = node_data
entity_type = properties["entity_type"] entity_type = properties["entity_type"]
entity_id = properties["entity_id"]
if "entity_id" not in properties: if "entity_id" not in properties:
raise ValueError("Neo4j: node properties must contain an 'entity_id' field") raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
query = ( query = (
""" """
MERGE (n:base {entity_id: $properties.entity_id}) MERGE (n:base {entity_id: $entity_id})
SET n += $properties SET n += $properties
SET n:`%s` SET n:`%s`
""" """
% entity_type % entity_type
) )
result = await tx.run(query, properties=properties) result = await tx.run(
query, entity_id=node_id, properties=properties
)
logger.debug( logger.debug(
f"Upserted node with entity_id '{entity_id}' and properties: {properties}" f"Upserted node with entity_id '{node_id}' and properties: {properties}"
) )
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
min_degree: int = 0, max_nodes: int = MAX_GRAPH_NODES,
inclusive: bool = False,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args: Args:
node_label: Label of the starting node node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph max_depth: Maximum depth of the subgraph, Defaults to 3
min_degree: Minimum degree of nodes to include. Defaults to 0 max_nodes: Maximum nodes to return by BFS, Defaults to 1000
inclusive: Do an inclusive search if true
Returns: Returns:
KnowledgeGraph: Complete connected subgraph for specified node KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
""" """
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
) as session: ) as session:
try: try:
if node_label == "*": if node_label == "*":
# First check total node count to determine if graph is truncated
count_query = "MATCH (n) RETURN count(n) as total"
count_result = None
try:
count_result = await session.run(count_query)
count_record = await count_result.single()
if count_record and count_record["total"] > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
)
finally:
if count_result:
await count_result.consume()
# Run main query to get nodes with highest degree
main_query = """ main_query = """
MATCH (n) MATCH (n)
OPTIONAL MATCH (n)-[r]-() OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree WITH n, COALESCE(count(r), 0) AS degree
WHERE degree >= $min_degree
ORDER BY degree DESC ORDER BY degree DESC
LIMIT $max_nodes LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes WITH collect({node: n}) AS filtered_nodes
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
RETURN filtered_nodes AS node_info, RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships collect(DISTINCT r) AS relationships
""" """
result_set = await session.run( result_set = None
main_query, try:
{"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree}, result_set = await session.run(
) main_query,
{"max_nodes": max_nodes},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
else: else:
# Main query uses partial matching # return await self._robust_fallback(node_label, max_depth, max_nodes)
main_query = """ # First try without limit to check if we need to truncate
full_query = """
MATCH (start) MATCH (start)
WHERE WHERE start.entity_id = $entity_id
CASE
WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
ELSE start.entity_id = $entity_id
END
WITH start WITH start
CALL apoc.path.subgraphAll(start, { CALL apoc.path.subgraphAll(start, {
relationshipFilter: '', relationshipFilter: '',
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
bfs: true bfs: true
}) })
YIELD nodes, relationships YIELD nodes, relationships
WITH start, nodes, relationships WITH nodes, relationships, size(nodes) AS total_nodes
UNWIND nodes AS node UNWIND nodes AS node
OPTIONAL MATCH (node)-[r]-() WITH collect({node: node}) AS node_info, relationships, total_nodes
WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships RETURN node_info, relationships, total_nodes
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
ORDER BY
CASE
WHEN node = start THEN 3
WHEN EXISTS((start)--(node)) THEN 2
ELSE 1
END DESC,
degree DESC
LIMIT $max_nodes
WITH collect({node: node}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
""" """
result_set = await session.run(
main_query,
{
"max_nodes": MAX_GRAPH_NODES,
"entity_id": node_label,
"inclusive": inclusive,
"max_depth": max_depth,
"min_degree": min_degree,
},
)
try: # Try to get full result
record = await result_set.single() full_result = None
try:
if record: full_result = await session.run(
# Handle nodes (compatible with multi-label cases) full_query,
for node_info in record["node_info"]: {
node = node_info["node"] "entity_id": node_label,
node_id = node.id "max_depth": max_depth,
if node_id not in seen_nodes: },
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
) )
finally: full_record = await full_result.single()
await result_set.consume() # Ensure result set is consumed
# If no record found, return empty KnowledgeGraph
if not full_record:
logger.debug(f"No nodes found for entity_id: {node_label}")
return result
# If record found, check node count
total_nodes = full_record["total_nodes"]
if total_nodes <= max_nodes:
# If node count is within limit, use full result directly
logger.debug(
f"Using full result with {total_nodes} nodes (no truncation needed)"
)
record = full_record
else:
# If node count exceeds limit, set truncated flag and run limited query
result.is_truncated = True
logger.info(
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
)
# Run limited query
limited_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
limit: $max_nodes,
bfs: true
})
YIELD nodes, relationships
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships
RETURN node_info, relationships
"""
result_set = None
try:
result_set = await session.run(
limited_query,
{
"entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
finally:
if full_result:
await full_result.consume()
if record:
# Handle nodes (compatible with multi-label cases)
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}") logger.warning(f"APOC plugin error: {str(e)}")
@@ -767,110 +837,28 @@ class Neo4JStorage(BaseGraphStorage):
logger.warning( logger.warning(
"Neo4j: falling back to basic Cypher recursive search..." "Neo4j: falling back to basic Cypher recursive search..."
) )
if inclusive: return await self._robust_fallback(node_label, max_depth, max_nodes)
logger.warning( else:
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching" logger.warning(
) "Neo4j: APOC plugin error with wildcard query, returning empty result"
return await self._robust_fallback(
node_label, max_depth, min_degree
) )
return result return result
async def _robust_fallback( async def _robust_fallback(
self, node_label: str, max_depth: int, min_degree: int = 0 self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Fallback implementation when APOC plugin is not available or incompatible. Fallback implementation when APOC plugin is not available or incompatible.
This method implements the same functionality as get_knowledge_graph but uses This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures. only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
""" """
from collections import deque
result = KnowledgeGraph() result = KnowledgeGraph()
visited_nodes = set() visited_nodes = set()
visited_edges = set() visited_edges = set()
visited_edge_pairs = set()  # Track processed edge pairs (sorted source_id, target_id)
async def traverse(
node: KnowledgeGraphNode,
edge: Optional[KnowledgeGraphEdge],
current_depth: int,
):
# Check traversal limits
if current_depth > max_depth:
logger.debug(f"Reached max depth: {max_depth}")
return
if len(visited_nodes) >= MAX_GRAPH_NODES:
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
return
# Check if node already visited
if node.id in visited_nodes:
return
# Get all edges and target nodes
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=node.id)
# Get all records and release database connection
records = await results.fetch(
1000
) # Max neighbour nodes we can handled
await results.consume() # Ensure results are consumed
# Nodes not connected to start node need to check degree
if current_depth > 1 and len(records) < min_degree:
return
# Add current node to result
result.nodes.append(node)
visited_nodes.add(node.id)
# Add edge to result if it exists and not already added
if edge and edge.id not in visited_edges:
result.edges.append(edge)
visited_edges.add(edge.id)
# Prepare nodes and edges for recursive processing
nodes_to_process = []
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=list(f"{target_id}"),
properties=dict(b_node.properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{node.id}",
target=f"{target_id}",
properties=dict(rel),
)
nodes_to_process.append((target_node, target_edge))
else:
logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
# Process nodes after releasing database connection
for target_node, target_edge in nodes_to_process:
await traverse(target_node, target_edge, current_depth + 1)
# Get the starting node's data # Get the starting node's data
async with self._driver.session( async with self._driver.session(
@@ -889,15 +877,129 @@ class Neo4JStorage(BaseGraphStorage):
    # Create initial KnowledgeGraphNode
    start_node = KnowledgeGraphNode(
        id=f"{node_record['n'].get('entity_id')}",
-       labels=list(f"{node_record['n'].get('entity_id')}"),
-       properties=dict(node_record["n"].properties),
+       labels=[node_record["n"].get("entity_id")],
+       properties=dict(node_record["n"]._properties),
    )
    finally:
        await node_result.consume()  # Ensure results are consumed
-   # Start traversal with the initial node
-   await traverse(start_node, None, 0)
+   # Initialize queue for BFS with (node, edge, depth) tuples
+   # edge is None for the starting node
+   queue = deque([(start_node, None, 0)])
# True BFS implementation using a queue
while queue and len(visited_nodes) < max_nodes:
# Dequeue the next node to process
current_node, current_edge, current_depth = queue.popleft()
# Skip if already visited or exceeds max depth
if current_node.id in visited_nodes:
continue
if current_depth > max_depth:
logger.debug(
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
)
continue
# Add current node to result
result.nodes.append(current_node)
visited_nodes.add(current_node.id)
# Add edge to result if it exists and not already added
if current_edge and current_edge.id not in visited_edges:
result.edges.append(current_edge)
visited_edges.add(current_edge.id)
# Stop if we've reached the node limit
if len(visited_nodes) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
break
# Get all edges and target nodes for the current node (even at max_depth)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=current_node.id)
# Get all records and release database connection
records = await results.fetch(1000) # Max neighbor nodes we can handle
await results.consume() # Ensure results are consumed
# Process all neighbors - capture all edges but only queue unvisited nodes
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=[target_id],
properties=dict(b_node._properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{current_node.id}",
target=f"{target_id}",
properties=dict(rel),
)
# Sort source_id and target_id so that (A,B) and (B,A) are treated as the same edge
sorted_pair = tuple(sorted([current_node.id, target_id]))
# Check whether this edge already exists (treating the graph as undirected)
if sorted_pair not in visited_edge_pairs:
# Only add the edge if the target node is already in the result or will be added to it
if target_id in visited_nodes or (
target_id not in visited_nodes
and current_depth < max_depth
):
result.edges.append(target_edge)
visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair)
# Only add unvisited nodes to the queue for further expansion
if target_id not in visited_nodes:
# Only add to queue if we're not at max depth yet
if current_depth < max_depth:
# Add node to queue with incremented depth
# Edge is already added to result, so we pass None as edge
queue.append((target_node, None, current_depth + 1))
else:
# At max depth, we've already added the edge but we don't add the node
# This prevents adding nodes beyond max_depth to the result
logger.debug(
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
)
else:
# If target node already exists in result, we don't need to add it again
logger.debug(
f"Node {target_id} already visited, edge added but node not queued"
)
else:
logger.warning(
f"Skipping edge {edge_id} due to missing entity_id on target node"
)
logger.info(
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result return result
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):
# Method 2: Query compatible with older versions # Method 2: Query compatible with older versions
query = """ query = """
-               MATCH (n)
+               MATCH (n:base)
WHERE n.entity_id IS NOT NULL WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label RETURN DISTINCT n.entity_id AS label
ORDER BY label ORDER BY label
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
self, algorithm: str self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError raise NotImplementedError
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources
This method will delete all nodes and relationships in the Neo4j database.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._driver.session(database=self._DATABASE) as session:
# Delete all nodes and relationships
query = "MATCH (n) DETACH DELETE n"
result = await session.run(query)
await result.consume() # Ensure result is fully consumed
logger.info(
f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
) )
nx.write_graphml(graph, file_name) nx.write_graphml(graph, file_name)
# TODO: deprecated, remove later
@staticmethod @staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph: def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
return None return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
   the KG-storage-log should be used to avoid data corruption
"""
graph = await self._get_graph() graph = await self._get_graph()
graph.add_node(node_id, **node_data) graph.add_node(node_id, **node_data)
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
   the KG-storage-log should be used to avoid data corruption
"""
graph = await self._get_graph() graph = await self._get_graph()
graph.add_edge(source_node_id, target_node_id, **edge_data) graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
"""
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
   the KG-storage-log should be used to avoid data corruption
"""
graph = await self._get_graph() graph = await self._get_graph()
if graph.has_node(node_id): if graph.has_node(node_id):
graph.remove_node(node_id) graph.remove_node(node_id)
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
else: else:
logger.warning(f"Node {node_id} not found in the graph for deletion.") logger.warning(f"Node {node_id} not found in the graph for deletion.")
# TODO: NOT USED
async def embed_nodes( async def embed_nodes(
self, algorithm: str self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
async def remove_nodes(self, nodes: list[str]): async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes """Delete multiple nodes
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
   the KG-storage-log should be used to avoid data corruption
Args: Args:
nodes: List of node IDs to be deleted nodes: List of node IDs to be deleted
""" """
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
async def remove_edges(self, edges: list[tuple[str, str]]): async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges """Delete multiple edges
Important notes:
1. Changes will be persisted to disk during the next index_done_callback
2. Only one process should update the storage at a time before index_done_callback;
   the KG-storage-log should be used to avoid data corruption
Args: Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple edges: List of edges to be deleted, each edge is a (source, target) tuple
""" """
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
    self,
    node_label: str,
    max_depth: int = 3,
-   min_degree: int = 0,
-   inclusive: bool = False,
+   max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
-       Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
-       When reducing the number of nodes, the prioritization criteria are as follows:
-       1. min_degree does not affect nodes directly connected to the matching nodes
-       2. Label matching nodes take precedence
-       3. Followed by nodes directly connected to the matching nodes
-       4. Finally, the degree of the nodes
        Args:
-           node_label: Label of the starting node
-           max_depth: Maximum depth of the subgraph
-           min_degree: Minimum degree of nodes to include. Defaults to 0
-           inclusive: Do an inclusive search if true
+           node_label: Label of the starting node; "*" means all nodes
+           max_depth: Maximum depth of the subgraph. Defaults to 3
+           max_nodes: Maximum number of nodes to return by BFS. Defaults to 1000
        Returns:
-           KnowledgeGraph object containing nodes and edges
+           KnowledgeGraph object containing nodes and edges, with an is_truncated flag
+           indicating whether the graph was truncated due to the max_nodes limit
        """
-       result = KnowledgeGraph()
-       seen_nodes = set()
-       seen_edges = set()
        graph = await self._get_graph()
-       # Initialize sets for start nodes and direct connected nodes
-       start_nodes = set()
-       direct_connected_nodes = set()
+       result = KnowledgeGraph()
        # Handle special case for "*" label
        if node_label == "*":
-           # For "*", return the entire graph including all nodes and edges
-           subgraph = (
-               graph.copy()
-           )  # Create a copy to avoid modifying the original graph
+           # Get degrees of all nodes
+           degrees = dict(graph.degree())
+           # Sort nodes by degree in descending order and take top max_nodes
+           sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
+           # Check if graph is truncated
+           if len(sorted_nodes) > max_nodes:
+               result.is_truncated = True
+               logger.info(
+                   f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
+               )
+           limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
+           # Create subgraph with the highest degree nodes
+           subgraph = graph.subgraph(limited_nodes)
        else:
-           # Find nodes with matching node id based on search_mode
-           nodes_to_explore = []
-           for n, attr in graph.nodes(data=True):
-               node_str = str(n)
-               if not inclusive:
-                   if node_label == node_str:  # Use exact matching
-                       nodes_to_explore.append(n)
-               else:  # inclusive mode
-                   if node_label in node_str:  # Use partial matching
-                       nodes_to_explore.append(n)
-           if not nodes_to_explore:
-               logger.warning(f"No nodes found with label {node_label}")
-               return result
-           # Get subgraph using ego_graph from all matching nodes
-           combined_subgraph = nx.Graph()
-           for start_node in nodes_to_explore:
-               node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
-               combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
-           # Get start nodes and direct connected nodes
-           if nodes_to_explore:
-               start_nodes = set(nodes_to_explore)
-               # Get nodes directly connected to all start nodes
-               for start_node in start_nodes:
-                   direct_connected_nodes.update(
-                       combined_subgraph.neighbors(start_node)
-                   )
-           # Remove start nodes from directly connected nodes (avoid duplicates)
-           direct_connected_nodes -= start_nodes
-           subgraph = combined_subgraph
+           # Check if node exists
+           if node_label not in graph:
+               logger.warning(f"Node {node_label} not found in the graph")
+               return KnowledgeGraph()  # Return empty graph
+           # Use BFS to get nodes
+           bfs_nodes = []
+           visited = set()
+           queue = [(node_label, 0)]  # (node, depth) tuple
+           # Breadth-first search
+           while queue and len(bfs_nodes) < max_nodes:
+               current, depth = queue.pop(0)
+               if current not in visited:
+                   visited.add(current)
+                   bfs_nodes.append(current)
+                   # Only explore neighbors if we haven't reached max_depth
+                   if depth < max_depth:
+                       # Add neighbor nodes to queue with incremented depth
+                       neighbors = list(graph.neighbors(current))
+                       queue.extend(
+                           [(n, depth + 1) for n in neighbors if n not in visited]
+                       )
+           # Check if graph is truncated - if we still have nodes in the queue
+           # and we've reached max_nodes, then the graph is truncated
+           if queue and len(bfs_nodes) >= max_nodes:
+               result.is_truncated = True
+               logger.info(
+                   f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
+               )
+           # Create subgraph with BFS discovered nodes
+           subgraph = graph.subgraph(bfs_nodes)
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
if node in start_nodes:
priority = 2
elif node in direct_connected_nodes:
priority = 1
else:
priority = 0
return (priority, degree)
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
:MAX_GRAPH_NODES
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
)
# Add nodes to result # Add nodes to result
seen_nodes = set()
seen_edges = set()
for node in subgraph.nodes(): for node in subgraph.nodes():
if str(node) in seen_nodes: if str(node) in seen_nodes:
continue continue
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
    for edge in subgraph.edges():
        source, target = edge
        # Ensure a unique edge_id for the undirected graph
-       if source > target:
+       if str(source) > str(target):
            source, target = target, source
        edge_id = f"{source}-{target}"
if edge_id in seen_edges: if edge_id in seen_edges:
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
return False # Return error return False # Return error
return True return True
async def drop(self) -> dict[str, str]:
"""Drop all graph data from storage and clean up resources
This method will:
1. Remove the graph storage file if it exists
2. Reset the graph to an empty state
3. Update flags to notify other processes
4. Changes are persisted to disk immediately
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
async with self._storage_lock:
# delete _client_file_name
if os.path.exists(self._graphml_xml_file):
os.remove(self._graphml_xml_file)
self._graph = nx.Graph()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
logger.info(
f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -8,17 +8,15 @@ import uuid
    from ..utils import logger
    from ..base import BaseVectorStorage
    import configparser
-   config = configparser.ConfigParser()
-   config.read("config.ini", "utf-8")
    import pipmaster as pm
    if not pm.is_installed("qdrant-client"):
        pm.install("qdrant-client")
-   from qdrant_client import QdrantClient, models
+   from qdrant_client import QdrantClient, models  # type: ignore
+   config = configparser.ConfigParser()
+   config.read("config.ini", "utf-8")
def compute_mdhash_id_for_qdrant( def compute_mdhash_id_for_qdrant(
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
except Exception as e: except Exception as e:
logger.error(f"Error searching for prefix '{prefix}': {e}") logger.error(f"Error searching for prefix '{prefix}': {e}")
return [] return []
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
Args:
id: The unique identifier of the vector
Returns:
The vector data if found, or None if not found
"""
try:
# Convert to Qdrant compatible ID
qdrant_id = compute_mdhash_id_for_qdrant(id)
# Retrieve the point by ID
result = self._client.retrieve(
collection_name=self.namespace,
ids=[qdrant_id],
with_payload=True,
)
if not result:
return None
return result[0].payload
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs
Args:
ids: List of unique identifiers
Returns:
List of vector data objects that were found
"""
if not ids:
return []
try:
# Convert to Qdrant compatible IDs
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
# Retrieve the points by IDs
results = self._client.retrieve(
collection_name=self.namespace,
ids=qdrant_ids,
with_payload=True,
)
return [point.payload for point in results]
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
async def drop(self) -> dict[str, str]:
"""Drop all vector data from storage and clean up resources
This method will delete all data from the Qdrant collection.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
"""
try:
# Delete the collection and recreate it
if self._client.collection_exists(self.namespace):
self._client.delete_collection(self.namespace)
# Recreate the collection
QdrantVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim,
distance=models.Distance.COSINE,
),
)
logger.info(
f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
return {"status": "error", "message": str(e)}

View File

@@ -12,6 +12,7 @@ if not pm.is_installed("redis"):
from redis.asyncio import Redis, ConnectionPool from redis.asyncio import Redis, ConnectionPool
from redis.exceptions import RedisError, ConnectionError from redis.exceptions import RedisError, ConnectionError
from lightrag.utils import logger, compute_mdhash_id from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseKVStorage from lightrag.base import BaseKVStorage
import json import json
@@ -121,7 +122,11 @@ class RedisKVStorage(BaseKVStorage):
except json.JSONEncodeError as e: except json.JSONEncodeError as e:
logger.error(f"JSON encode error during upsert: {e}") logger.error(f"JSON encode error during upsert: {e}")
raise raise
async def index_done_callback(self) -> None:
# Redis handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete entries with specified IDs""" """Delete entries with specified IDs"""
if not ids: if not ids:
@@ -138,71 +143,52 @@ class RedisKVStorage(BaseKVStorage):
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
) )
-   async def delete_entity(self, entity_name: str) -> None:
-       """Delete an entity by name"""
-       try:
-           entity_id = compute_mdhash_id(entity_name, prefix="ent-")
-           logger.debug(
-               f"Attempting to delete entity {entity_name} with ID {entity_id}"
-           )
-           async with self._get_redis_connection() as redis:
-               result = await redis.delete(f"{self.namespace}:{entity_id}")
-               if result:
-                   logger.debug(f"Successfully deleted entity {entity_name}")
-               else:
-                   logger.debug(f"Entity {entity_name} not found in storage")
-       except Exception as e:
-           logger.error(f"Error deleting entity {entity_name}: {e}")
-
-   async def delete_entity_relation(self, entity_name: str) -> None:
-       """Delete all relations associated with an entity"""
-       try:
-           async with self._get_redis_connection() as redis:
-               cursor = 0
-               relation_keys = []
-               pattern = f"{self.namespace}:*"
-               while True:
-                   cursor, keys = await redis.scan(cursor, match=pattern)
-                   # Process keys in batches
-                   pipe = redis.pipeline()
-                   for key in keys:
-                       pipe.get(key)
-                   values = await pipe.execute()
-                   for key, value in zip(keys, values):
-                       if value:
-                           try:
-                               data = json.loads(value)
-                               if (
-                                   data.get("src_id") == entity_name
-                                   or data.get("tgt_id") == entity_name
-                               ):
-                                   relation_keys.append(key)
-                           except json.JSONDecodeError:
-                               logger.warning(f"Invalid JSON in key {key}")
-                               continue
-                   if cursor == 0:
-                       break
-               # Delete relations in batches
-               if relation_keys:
-                   # Delete in chunks to avoid too many arguments
-                   chunk_size = 1000
-                   for i in range(0, len(relation_keys), chunk_size):
-                       chunk = relation_keys[i:i + chunk_size]
-                       deleted = await redis.delete(*chunk)
-                       logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
-               else:
-                   logger.debug(f"No relations found for entity {entity_name}")
-       except Exception as e:
-           logger.error(f"Error deleting relations for {entity_name}: {e}")
+   async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
+       """Delete specific records from storage by cache mode
+
+       Important note for Redis storage:
+       1. This will immediately delete the specified cache modes from Redis
+
+       Args:
+           modes (list[str]): List of cache modes to be dropped from storage
+
+       Returns:
+           True: if the cache drop succeeded
+           False: if the cache drop failed
+       """
+       if not modes:
+           return False
+       try:
+           await self.delete(modes)
+           return True
+       except Exception:
+           return False
+
+   async def drop(self) -> dict[str, str]:
+       """Drop the storage by removing all keys under the current namespace.
+
+       Returns:
+           dict[str, str]: Status of the operation with keys 'status' and 'message'
+       """
+       async with self._get_redis_connection() as redis:
+           try:
+               keys = await redis.keys(f"{self.namespace}:*")
+               if keys:
+                   pipe = redis.pipeline()
+                   for key in keys:
+                       pipe.delete(key)
+                   results = await pipe.execute()
+                   deleted_count = sum(results)
+                   logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
+                   return {"status": "success", "message": f"{deleted_count} keys dropped"}
+               else:
+                   logger.info(f"No keys found to drop in {self.namespace}")
+                   return {"status": "success", "message": "no keys to drop"}
+           except Exception as e:
+               logger.error(f"Error dropping keys from {self.namespace}: {e}")
+               return {"status": "error", "message": str(e)}
async def index_done_callback(self) -> None:
# Redis handles persistence automatically
pass
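A short usage sketch for the new mode-scoped cache deletion; `kv_storage` is a placeholder for the `RedisKVStorage` instance that backs the LLM response cache.

```python
import asyncio

async def clear_llm_cache(kv_storage) -> None:
    # `kv_storage` is assumed to be the RedisKVStorage backing the LLM cache
    ok = await kv_storage.drop_cache_by_modes(["local", "global"])
    print("cache cleared" if ok else "nothing cleared or deletion failed")

# asyncio.run(clear_llm_cache(kv_storage))  # requires a live Redis connection
```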

View File

@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"): if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy") pm.install("sqlalchemy")
-   from sqlalchemy import create_engine, text
+   from sqlalchemy import create_engine, text  # type: ignore
class TiDB: class TiDB:
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
# Ti handles persistence automatically # Ti handles persistence automatically
pass pass
async def delete(self, ids: list[str]) -> None:
"""Delete records with specified IDs from the storage.
Args:
ids: List of record IDs to be deleted
"""
if not ids:
return
try:
table_name = namespace_to_table_name(self.namespace)
id_field = namespace_to_id(self.namespace)
if not table_name or not id_field:
logger.error(f"Unknown namespace for deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.info(
f"Successfully deleted {len(ids)} records from {self.namespace}"
)
except Exception as e:
logger.error(f"Error deleting records from {self.namespace}: {e}")
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
bool: True if successful, False otherwise
"""
if not modes:
return False
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return False
if table_name != "LIGHTRAG_LLM_CACHE":
return False
# Build a MySQL-style IN clause
modes_list = ", ".join([f"'{mode}'" for mode in modes])
sql = f"""
DELETE FROM {table_name}
WHERE workspace = :workspace
AND mode IN ({modes_list})
"""
logger.info(f"Deleting cache by modes: {modes}")
await self.db.execute(sql, {"workspace": self.db.workspace})
return True
except Exception as e:
logger.error(f"Error deleting cache by modes {modes}: {e}")
return False
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@final @final
@dataclass @dataclass
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True) return await self.db.query(SQL, params, multirows=True)
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs from the storage.
Args:
ids: List of vector IDs to be deleted
"""
if not ids:
return
table_name = namespace_to_table_name(self.namespace)
id_field = namespace_to_id(self.namespace)
if not table_name or not id_field:
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
-   raise NotImplementedError
+   """Delete an entity by its name from the vector storage.
Args:
entity_name: The name of the entity to delete
"""
try:
# Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace AND name = :entity_name"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
-   raise NotImplementedError
+   """Delete all relations associated with an entity.
Args:
entity_name: The name of the entity whose relations should be deleted
"""
try:
# Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# Ti handles persistence automatically # Ti handles persistence automatically
pass pass
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]: async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix. """Search for records with IDs starting with a specific prefix.
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
# Ti handles persistence automatically # Ti handles persistence automatically
pass pass
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
drop_sql = """
DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
"""
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
"""Delete a node and all its related edges """Delete a node and all its related edges
@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
FROM LIGHTRAG_DOC_CHUNKS FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
""", """,
# Drop tables
"drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
} }
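As a small illustration of how the drop template above is rendered: the table name is interpolated with `str.format` (identifiers cannot be bound parameters), while the workspace value stays a bound parameter. A minimal sketch using the names from the diff:

```python
SQL_TEMPLATES = {
    "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}

# Render the statement for one table; the workspace stays a bound parameter
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
    table_name="LIGHTRAG_DOC_CHUNKS"
)
print(drop_sql)  # DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace
# await self.db.execute(drop_sql, {"workspace": self.db.workspace})
```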

View File

@@ -13,7 +13,6 @@ import pandas as pd
from lightrag.kg import ( from lightrag.kg import (
STORAGE_ENV_REQUIREMENTS,
STORAGES, STORAGES,
verify_storage_implementation, verify_storage_implementation,
) )
@@ -230,6 +229,7 @@ class LightRAG:
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict) vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional parameters for vector database storage.""" """Additional parameters for vector database storage."""
# TODO: deprecated, remove in the future, use WORKSPACE instead
namespace_prefix: str = field(default="") namespace_prefix: str = field(default="")
"""Prefix for namespacing stored data across different environments.""" """Prefix for namespacing stored data across different environments."""
@@ -510,36 +510,22 @@ class LightRAG:
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
min_degree: int = 0, max_nodes: int = 1000,
inclusive: bool = False,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
"""Get knowledge graph for a given label """Get knowledge graph for a given label
Args: Args:
node_label (str): Label to get knowledge graph for node_label (str): Label to get knowledge graph for
max_depth (int): Maximum depth of graph max_depth (int): Maximum depth of graph
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0. max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
Returns: Returns:
KnowledgeGraph: Knowledge graph containing nodes and edges KnowledgeGraph: Knowledge graph containing nodes and edges
""" """
# get params supported by get_knowledge_graph of specified storage
import inspect
storage_params = inspect.signature( return await self.chunk_entity_relation_graph.get_knowledge_graph(
self.chunk_entity_relation_graph.get_knowledge_graph node_label, max_depth, max_nodes
).parameters )
kwargs = {"node_label": node_label, "max_depth": max_depth}
if "min_degree" in storage_params and min_degree > 0:
kwargs["min_degree"] = min_degree
if "inclusive" in storage_params:
kwargs["inclusive"] = inclusive
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
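A hedged caller-side sketch of the simplified signature; `rag` is a placeholder for an already-initialized `LightRAG` instance.

```python
import asyncio

async def preview_graph(rag) -> None:
    # `rag` is assumed to be an initialized LightRAG instance
    kg = await rag.get_knowledge_graph("*", max_depth=2, max_nodes=100)
    print(f"{len(kg.nodes)} nodes / {len(kg.edges)} edges, truncated={kg.is_truncated}")

# asyncio.run(preview_graph(rag))  # requires an initialized LightRAG instance
```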
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name] import_path = STORAGES[storage_name]
@@ -1449,6 +1435,7 @@ class LightRAG:
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.adelete_by_entity(entity_name)) return loop.run_until_complete(self.adelete_by_entity(entity_name))
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def adelete_by_entity(self, entity_name: str) -> None: async def adelete_by_entity(self, entity_name: str) -> None:
try: try:
await self.entities_vdb.delete_entity(entity_name) await self.entities_vdb.delete_entity(entity_name)
@@ -1486,6 +1473,7 @@ class LightRAG:
self.adelete_by_relation(source_entity, target_entity) self.adelete_by_relation(source_entity, target_entity)
) )
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None: async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
"""Asynchronously delete a relation between two entities. """Asynchronously delete a relation between two entities.
@@ -1494,6 +1482,7 @@ class LightRAG:
target_entity: Name of the target entity target_entity: Name of the target entity
""" """
try: try:
# TODO: check if has_edge function works on reverse relation
# Check if the relation exists # Check if the relation exists
edge_exists = await self.chunk_entity_relation_graph.has_edge( edge_exists = await self.chunk_entity_relation_graph.has_edge(
source_entity, target_entity source_entity, target_entity
@@ -1554,6 +1543,7 @@ class LightRAG:
""" """
return await self.doc_status.get_docs_by_status(status) return await self.doc_status.get_docs_by_status(status)
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def adelete_by_doc_id(self, doc_id: str) -> None: async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data """Delete a document and all its related data
@@ -1586,6 +1576,8 @@ class LightRAG:
chunk_ids = set(related_chunks.keys()) chunk_ids = set(related_chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete") logger.debug(f"Found {len(chunk_ids)} chunks to delete")
# TODO: self.entities_vdb.client_storage only works for local storage, need to fix this
# 3. Before deleting, check the related entities and relationships for these chunks # 3. Before deleting, check the related entities and relationships for these chunks
for chunk_id in chunk_ids: for chunk_id in chunk_ids:
# Check entities # Check entities
@@ -1857,24 +1849,6 @@ class LightRAG:
return result return result
def check_storage_env_vars(self, storage_name: str) -> None:
"""Check if all required environment variables for storage implementation exist
Args:
storage_name: Storage implementation name
Raises:
ValueError: If required environment variables are missing
"""
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
async def aclear_cache(self, modes: list[str] | None = None) -> None: async def aclear_cache(self, modes: list[str] | None = None) -> None:
"""Clear cache data from the LLM response cache storage. """Clear cache data from the LLM response cache storage.
@@ -1906,12 +1880,18 @@ class LightRAG:
try: try:
# Reset the cache storage for specified mode # Reset the cache storage for specified mode
if modes: if modes:
-               await self.llm_response_cache.delete(modes)
-               logger.info(f"Cleared cache for modes: {modes}")
+               success = await self.llm_response_cache.drop_cache_by_modes(modes)
+               if success:
+                   logger.info(f"Cleared cache for modes: {modes}")
+               else:
+                   logger.warning(f"Failed to clear cache for modes: {modes}")
            else:
                # Clear all modes
-               await self.llm_response_cache.delete(valid_modes)
-               logger.info("Cleared all cache")
+               success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
+               if success:
+                   logger.info("Cleared all cache")
+               else:
+                   logger.warning("Failed to clear all cache")
await self.llm_response_cache.index_done_callback() await self.llm_response_cache.index_done_callback()
@@ -1922,6 +1902,7 @@ class LightRAG:
"""Synchronous version of aclear_cache.""" """Synchronous version of aclear_cache."""
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes)) return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def aedit_entity( async def aedit_entity(
self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -2134,6 +2115,7 @@ class LightRAG:
] ]
) )
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def aedit_relation( async def aedit_relation(
self, source_entity: str, target_entity: str, updated_data: dict[str, Any] self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -2448,6 +2430,7 @@ class LightRAG:
self.acreate_relation(source_entity, target_entity, relation_data) self.acreate_relation(source_entity, target_entity, relation_data)
) )
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
async def amerge_entities( async def amerge_entities(
self, self,
source_entities: list[str], source_entities: list[str],

View File

@@ -44,6 +44,47 @@ class InvalidResponseError(Exception):
pass pass
def create_openai_async_client(
api_key: str | None = None,
base_url: str | None = None,
client_configs: dict[str, Any] = None,
) -> AsyncOpenAI:
"""Create an AsyncOpenAI client with the given configuration.
Args:
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
Returns:
An AsyncOpenAI client instance.
"""
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
if client_configs is None:
client_configs = {}
# Create a merged config dict with precedence: explicit params > client_configs > defaults
merged_configs = {
**client_configs,
"default_headers": default_headers,
"api_key": api_key,
}
if base_url is not None:
merged_configs["base_url"] = base_url
return AsyncOpenAI(**merged_configs)
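A brief sketch of the precedence this factory implements: explicit `api_key`/`base_url` win over anything in `client_configs`, which in turn overrides the defaults (the custom User-Agent header). It assumes `create_openai_async_client` from the section above; the key and URL shown are placeholders.

```python
client = create_openai_async_client(
    api_key="sk-placeholder",
    base_url="https://api.openai.com/v1",
    client_configs={"timeout": 30.0, "max_retries": 2},  # forwarded to AsyncOpenAI
)
```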
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -61,29 +102,52 @@ async def openai_complete_if_cache(
token_tracker: Any | None = None, token_tracker: Any | None = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Complete a prompt using OpenAI's API with caching support.
Args:
model: The OpenAI model to use.
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by
explicit parameters (api_key, base_url).
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
Returns:
The completed text or an async iterator of text chunks if streaming.
Raises:
InvalidResponseError: If the response from OpenAI is invalid or empty.
APIConnectionError: If there is a connection error with the OpenAI API.
RateLimitError: If the OpenAI API rate limit is exceeded.
APITimeoutError: If the OpenAI API request times out.
"""
if history_messages is None: if history_messages is None:
history_messages = [] history_messages = []
if not api_key:
api_key = os.environ["OPENAI_API_KEY"]
default_headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
# Set openai logger level to INFO when VERBOSE_DEBUG is off # Set openai logger level to INFO when VERBOSE_DEBUG is off
if not VERBOSE_DEBUG and logger.level == logging.DEBUG: if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
logging.getLogger("openai").setLevel(logging.INFO) logging.getLogger("openai").setLevel(logging.INFO)
-   openai_async_client = (
-       AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-       if base_url is None
-       else AsyncOpenAI(
-           base_url=base_url, default_headers=default_headers, api_key=api_key
-       )
-   )
+   # Extract client configuration options
+   client_configs = kwargs.pop("openai_client_configs", {})
+
+   # Create the OpenAI client
+   openai_async_client = create_openai_async_client(
+       api_key=api_key, base_url=base_url, client_configs=client_configs
+   )
# Remove special kwargs that shouldn't be passed to OpenAI
kwargs.pop("hashing_kv", None) kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None) kwargs.pop("keyword_extraction", None)
# Prepare messages
messages: list[dict[str, Any]] = [] messages: list[dict[str, Any]] = []
if system_prompt: if system_prompt:
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
@@ -272,21 +336,32 @@ async def openai_embed(
model: str = "text-embedding-3-small", model: str = "text-embedding-3-small",
base_url: str = None, base_url: str = None,
api_key: str = None, api_key: str = None,
client_configs: dict[str, Any] = None,
) -> np.ndarray: ) -> np.ndarray:
-   if not api_key:
-       api_key = os.environ["OPENAI_API_KEY"]
-   default_headers = {
-       "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-       "Content-Type": "application/json",
-   }
-   openai_async_client = (
-       AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-       if base_url is None
-       else AsyncOpenAI(
-           base_url=base_url, default_headers=default_headers, api_key=api_key
-       )
-   )
+   """Generate embeddings for a list of texts using OpenAI's API.
+
+   Args:
+       texts: List of texts to embed.
+       model: The OpenAI embedding model to use.
+       base_url: Optional base URL for the OpenAI API.
+       api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+       client_configs: Additional configuration options for the AsyncOpenAI client.
+           These will override any default configurations but will be overridden by
+           explicit parameters (api_key, base_url).
+
+   Returns:
+       A numpy array of embeddings, one per input text.
+
+   Raises:
+       APIConnectionError: If there is a connection error with the OpenAI API.
+       RateLimitError: If the OpenAI API rate limit is exceeded.
+       APITimeoutError: If the OpenAI API request times out.
+   """
+   # Create the OpenAI client
+   openai_async_client = create_openai_async_client(
+       api_key=api_key, base_url=base_url, client_configs=client_configs
+   )
response = await openai_async_client.embeddings.create( response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float" model=model, input=texts, encoding_format="float"
) )
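A small usage sketch for the updated embedding helper; it assumes `openai_embed` from the diff above and `OPENAI_API_KEY` in the environment.

```python
import asyncio
import numpy as np

async def embed_demo() -> None:
    # assumes openai_embed from above and OPENAI_API_KEY in the environment
    vecs = await openai_embed(
        ["hello world", "lightrag"],
        model="text-embedding-3-small",
        client_configs={"timeout": 30.0},
    )
    assert isinstance(vecs, np.ndarray) and vecs.shape[0] == 2

# asyncio.run(embed_demo())  # requires network access and a valid API key
```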

View File

@@ -26,7 +26,6 @@ from .utils import (
CacheData, CacheData,
statistic_data, statistic_data,
get_conversation_turns, get_conversation_turns,
verbose_debug,
) )
from .base import ( from .base import (
BaseGraphStorage, BaseGraphStorage,
@@ -442,6 +441,13 @@ async def extract_entities(
processed_chunks = 0 processed_chunks = 0
total_chunks = len(ordered_chunks) total_chunks = len(ordered_chunks)
total_entities_count = 0
total_relations_count = 0
# Get lock manager from shared storage
from .kg.shared_storage import get_graph_db_lock
graph_db_lock = get_graph_db_lock(enable_logging=False)
async def _user_llm_func_with_cache( async def _user_llm_func_with_cache(
input_text: str, history_messages: list[dict[str, str]] = None input_text: str, history_messages: list[dict[str, str]] = None
@@ -540,7 +546,7 @@ async def extract_entities(
chunk_key_dp (tuple[str, TextChunkSchema]): chunk_key_dp (tuple[str, TextChunkSchema]):
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
""" """
-       nonlocal processed_chunks
+       nonlocal processed_chunks, total_entities_count, total_relations_count
chunk_key = chunk_key_dp[0] chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1] chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"] content = chunk_dp["content"]
@@ -598,102 +604,74 @@ async def extract_entities(
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message) pipeline_status["history_messages"].append(log_message)
-   tasks = [_process_single_content(c) for c in ordered_chunks]
-   results = await asyncio.gather(*tasks)
-
-   maybe_nodes = defaultdict(list)
-   maybe_edges = defaultdict(list)
-   for m_nodes, m_edges in results:
-       for k, v in m_nodes.items():
-           maybe_nodes[k].extend(v)
-       for k, v in m_edges.items():
-           maybe_edges[tuple(sorted(k))].extend(v)
-
-   from .kg.shared_storage import get_graph_db_lock
-
-   graph_db_lock = get_graph_db_lock(enable_logging=False)
-
-   # Ensure that nodes and edges are merged and upserted atomically
-   async with graph_db_lock:
-       all_entities_data = await asyncio.gather(
-           *[
-               _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-               for k, v in maybe_nodes.items()
-           ]
-       )
-
-       all_relationships_data = await asyncio.gather(
-           *[
-               _merge_edges_then_upsert(
-                   k[0], k[1], v, knowledge_graph_inst, global_config
-               )
-               for k, v in maybe_edges.items()
-           ]
-       )
-
-   if not (all_entities_data or all_relationships_data):
-       log_message = "Didn't extract any entities and relationships."
-       logger.info(log_message)
-       if pipeline_status is not None:
-           async with pipeline_status_lock:
-               pipeline_status["latest_message"] = log_message
-               pipeline_status["history_messages"].append(log_message)
-       return
-
-   if not all_entities_data:
-       log_message = "Didn't extract any entities"
-       logger.info(log_message)
-       if pipeline_status is not None:
-           async with pipeline_status_lock:
-               pipeline_status["latest_message"] = log_message
-               pipeline_status["history_messages"].append(log_message)
-   if not all_relationships_data:
-       log_message = "Didn't extract any relationships"
-       logger.info(log_message)
-       if pipeline_status is not None:
-           async with pipeline_status_lock:
-               pipeline_status["latest_message"] = log_message
-               pipeline_status["history_messages"].append(log_message)
-
-   log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
+       # Use graph database lock to ensure atomic merges and updates
+       chunk_entities_data = []
+       chunk_relationships_data = []
+
+       async with graph_db_lock:
+           # Process and update entities
+           for entity_name, entities in maybe_nodes.items():
+               entity_data = await _merge_nodes_then_upsert(
+                   entity_name, entities, knowledge_graph_inst, global_config
+               )
+               chunk_entities_data.append(entity_data)
+
+           # Process and update relationships
+           for edge_key, edges in maybe_edges.items():
+               # Ensure edge direction consistency
+               sorted_edge_key = tuple(sorted(edge_key))
+               edge_data = await _merge_edges_then_upsert(
+                   sorted_edge_key[0],
+                   sorted_edge_key[1],
+                   edges,
+                   knowledge_graph_inst,
+                   global_config,
+               )
+               chunk_relationships_data.append(edge_data)
+
+           # Update vector database (within the same lock to ensure atomicity)
+           if entity_vdb is not None and chunk_entities_data:
+               data_for_vdb = {
+                   compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                       "entity_name": dp["entity_name"],
+                       "entity_type": dp["entity_type"],
+                       "content": f"{dp['entity_name']}\n{dp['description']}",
+                       "source_id": dp["source_id"],
+                       "file_path": dp.get("file_path", "unknown_source"),
+                   }
+                   for dp in chunk_entities_data
+               }
+               await entity_vdb.upsert(data_for_vdb)
+
+           if relationships_vdb is not None and chunk_relationships_data:
+               data_for_vdb = {
+                   compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                       "src_id": dp["src_id"],
+                       "tgt_id": dp["tgt_id"],
+                       "keywords": dp["keywords"],
+                       "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
+                       "source_id": dp["source_id"],
+                       "file_path": dp.get("file_path", "unknown_source"),
+                   }
+                   for dp in chunk_relationships_data
+               }
+               await relationships_vdb.upsert(data_for_vdb)
+
+           # Update counters
+           total_entities_count += len(chunk_entities_data)
+           total_relations_count += len(chunk_relationships_data)
+
+   # Handle all chunks in parallel
+   tasks = [_process_single_content(c) for c in ordered_chunks]
+   await asyncio.gather(*tasks)
+
+   log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
logger.info(log_message) logger.info(log_message)
if pipeline_status is not None: if pipeline_status is not None:
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message) pipeline_status["history_messages"].append(log_message)
verbose_debug(
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
)
verbose_debug(f"New relationships:{all_relationships_data}")
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in all_entities_data
}
await entity_vdb.upsert(data_for_vdb)
if relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in all_relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
async def kg_query( async def kg_query(
@@ -720,8 +698,7 @@ async def kg_query(
if cached_response is not None: if cached_response is not None:
return cached_response return cached_response
-   # Extract keywords using extract_keywords_only function which already supports conversation history
-   hl_keywords, ll_keywords = await extract_keywords_only(
+   hl_keywords, ll_keywords = await get_keywords_from_query(
        query, query_param, global_config, hashing_kv
    )
@@ -817,6 +794,38 @@ async def kg_query(
return response return response
async def get_keywords_from_query(
query: str,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
"""
Retrieves high-level and low-level keywords for RAG operations.
This function checks if keywords are already provided in query parameters,
and if not, extracts them from the query text using LLM.
Args:
query: The user's query text
query_param: Query parameters that may contain pre-defined keywords
global_config: Global configuration dictionary
hashing_kv: Optional key-value storage for caching results
Returns:
A tuple containing (high_level_keywords, low_level_keywords)
"""
# Check if pre-defined keywords are already provided
if query_param.hl_keywords or query_param.ll_keywords:
return query_param.hl_keywords, query_param.ll_keywords
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
)
return hl_keywords, ll_keywords
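The short-circuit behavior of `get_keywords_from_query` can be sketched as follows; the import path and attribute names are assumptions based on the diff, and the `await` line is left commented because it needs a live configuration.

```python
from lightrag import QueryParam  # assumed import path for QueryParam

param = QueryParam(mode="local")
param.hl_keywords = ["graph storage"]   # pre-defined high-level keywords
param.ll_keywords = ["Neo4j", "BFS"]    # pre-defined low-level keywords

# hl, ll = await get_keywords_from_query(query, param, global_config, hashing_kv)
# -> returns (param.hl_keywords, param.ll_keywords) without an LLM call
```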
async def extract_keywords_only( async def extract_keywords_only(
text: str, text: str,
param: QueryParam, param: QueryParam,
@@ -957,8 +966,7 @@ async def mix_kg_vector_query(
# 2. Execute knowledge graph and vector searches in parallel # 2. Execute knowledge graph and vector searches in parallel
async def get_kg_context(): async def get_kg_context():
try: try:
-           # Extract keywords using extract_keywords_only function which already supports conversation history
-           hl_keywords, ll_keywords = await extract_keywords_only(
+           hl_keywords, ll_keywords = await get_keywords_from_query(
                query, query_param, global_config, hashing_kv
            )
@@ -1339,7 +1347,9 @@ async def _get_node_data(
text_units_section_list = [["id", "content", "file_path"]] text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units): for i, t in enumerate(use_text_units):
-       text_units_section_list.append([i, t["content"], t["file_path"]])
+       text_units_section_list.append(
+           [i, t["content"], t.get("file_path", "unknown_source")]
+       )
text_units_context = list_of_list_to_csv(text_units_section_list) text_units_context = list_of_list_to_csv(text_units_section_list)
return entities_context, relations_context, text_units_context return entities_context, relations_context, text_units_context
@@ -2043,16 +2053,13 @@ async def query_with_keywords(
Query response or async iterator Query response or async iterator
""" """
# Extract keywords # Extract keywords
-   hl_keywords, ll_keywords = await extract_keywords_only(
-       text=query,
-       param=param,
+   hl_keywords, ll_keywords = await get_keywords_from_query(
+       query=query,
+       query_param=param,
        global_config=global_config,
        hashing_kv=hashing_kv,
    )
-   param.hl_keywords = hl_keywords
-   param.ll_keywords = ll_keywords
# Create a new string with the prompt and the keywords # Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords) ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords) hl_keywords_str = ", ".join(hl_keywords)
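Reviewer note: with this refactor, every query path resolves keywords through `get_keywords_from_query`, so callers can pre-seed keywords and skip the LLM extraction step entirely. A minimal sketch of that behavior from the public API (the `rag` instance and its LLM/embedding configuration are assumed, not shown):

```python
from lightrag import QueryParam

# Assumes `rag` is an already-initialized LightRAG instance.
param = QueryParam(
    mode="hybrid",
    hl_keywords=["knowledge graph", "retrieval"],   # high-level themes
    ll_keywords=["entity", "relation", "chunk"],    # concrete terms
)

# Because both keyword lists are pre-set, get_keywords_from_query() returns
# them directly and no LLM call is made for keyword extraction:
# result = await rag.aquery("How are entities linked?", param=param)
```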

View File

@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
 class KnowledgeGraph(BaseModel):
     nodes: list[KnowledgeGraphNode] = []
     edges: list[KnowledgeGraphEdge] = []
+    is_truncated: bool = False
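Side note: `is_truncated` lets API clients detect when the server capped the returned subgraph (see the new `max_nodes` parameter below). A small sketch of the intended client-side check, assuming the module path matches this repo's `lightrag/types.py`:

```python
from lightrag.types import KnowledgeGraph  # assumed import path

def warn_if_truncated(kg: KnowledgeGraph) -> None:
    # The server sets is_truncated when the requested subgraph was cut
    # down to the max_nodes limit, so the view may be incomplete.
    if kg.is_truncated:
        print(f"Graph truncated: {len(kg.nodes)} nodes / {len(kg.edges)} edges shown")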

View File

@@ -3,12 +3,13 @@ import ThemeProvider from '@/components/ThemeProvider'
 import TabVisibilityProvider from '@/contexts/TabVisibilityProvider'
 import ApiKeyAlert from '@/components/ApiKeyAlert'
 import StatusIndicator from '@/components/status/StatusIndicator'
-import { healthCheckInterval } from '@/lib/constants'
+import { healthCheckInterval, SiteInfo, webuiPrefix } from '@/lib/constants'
 import { useBackendState, useAuthStore } from '@/stores/state'
 import { useSettingsStore } from '@/stores/settings'
 import { getAuthStatus } from '@/api/lightrag'
 import SiteHeader from '@/features/SiteHeader'
 import { InvalidApiKeyError, RequireApiKeError } from '@/api/lightrag'
+import { ZapIcon } from 'lucide-react'
 import GraphViewer from '@/features/GraphViewer'
 import DocumentManager from '@/features/DocumentManager'
@@ -22,6 +23,7 @@ function App() {
   const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
   const currentTab = useSettingsStore.use.currentTab()
   const [apiKeyAlertOpen, setApiKeyAlertOpen] = useState(false)
+  const [initializing, setInitializing] = useState(true) // Add initializing state
   const versionCheckRef = useRef(false); // Prevent duplicate calls in Vite dev mode

   const handleApiKeyAlertOpenChange = useCallback((open: boolean) => {
@@ -55,29 +57,48 @@ function App() {
       // Check if version info was already obtained in login page
       const versionCheckedFromLogin = sessionStorage.getItem('VERSION_CHECKED_FROM_LOGIN') === 'true';
-      if (versionCheckedFromLogin) return;
-
-      // Get version info
-      const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
-      if (!token) return;
+      if (versionCheckedFromLogin) {
+        setInitializing(false); // Skip initialization if already checked
+        return;
+      }

       try {
+        setInitializing(true); // Start initialization
+
+        // Get version info
+        const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
         const status = await getAuthStatus();
-        if (status.core_version || status.api_version) {
+
+        // If auth is not configured and a new token is returned, use the new token
+        if (!status.auth_configured && status.access_token) {
+          useAuthStore.getState().login(
+            status.access_token, // Use the new token
+            true, // Guest mode
+            status.core_version,
+            status.api_version,
+            status.webui_title || null,
+            status.webui_description || null
+          );
+        } else if (token && (status.core_version || status.api_version || status.webui_title || status.webui_description)) {
+          // Otherwise use the old token (if it exists)
           const isGuestMode = status.auth_mode === 'disabled' || useAuthStore.getState().isGuestMode;
+          // Update version info while maintaining login state
           useAuthStore.getState().login(
             token,
             isGuestMode,
             status.core_version,
-            status.api_version
+            status.api_version,
+            status.webui_title || null,
+            status.webui_description || null
           );
+          // Set flag to indicate version info has been checked
+          sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true');
         }
-        // Set flag to indicate version info has been checked
-        sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true');
       } catch (error) {
         console.error('Failed to get version info:', error);
+      } finally {
+        // Ensure initializing is set to false even if there's an error
+        setInitializing(false);
       }
     };
@@ -101,31 +122,63 @@ function App() {
   return (
     <ThemeProvider>
       <TabVisibilityProvider>
-        <main className="flex h-screen w-screen overflow-hidden">
-          <Tabs
-            defaultValue={currentTab}
-            className="!m-0 flex grow flex-col !p-0 overflow-hidden"
-            onValueChange={handleTabChange}
-          >
-            <SiteHeader />
-            <div className="relative grow">
-              <TabsContent value="documents" className="absolute top-0 right-0 bottom-0 left-0 overflow-auto">
-                <DocumentManager />
-              </TabsContent>
-              <TabsContent value="knowledge-graph" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
-                <GraphViewer />
-              </TabsContent>
-              <TabsContent value="retrieval" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
-                <RetrievalTesting />
-              </TabsContent>
-              <TabsContent value="api" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
-                <ApiSite />
-              </TabsContent>
-            </div>
-          </Tabs>
-          {enableHealthCheck && <StatusIndicator />}
-          <ApiKeyAlert open={apiKeyAlertOpen} onOpenChange={handleApiKeyAlertOpenChange} />
-        </main>
+        {initializing ? (
+          // Loading state while initializing with simplified header
+          <div className="flex h-screen w-screen flex-col">
+            {/* Simplified header during initialization - matches SiteHeader structure */}
+            <header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
+              <div className="min-w-[200px] w-auto flex items-center">
+                <a href={webuiPrefix} className="flex items-center gap-2">
+                  <ZapIcon className="size-4 text-emerald-400" aria-hidden="true" />
+                  <span className="font-bold md:inline-block">{SiteInfo.name}</span>
+                </a>
+              </div>
+
+              {/* Empty middle section to maintain layout */}
+              <div className="flex h-10 flex-1 items-center justify-center">
+              </div>
+
+              {/* Empty right section to maintain layout */}
+              <nav className="w-[200px] flex items-center justify-end">
+              </nav>
+            </header>
+
+            {/* Loading indicator in content area */}
+            <div className="flex flex-1 items-center justify-center">
+              <div className="text-center">
+                <div className="mb-2 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"></div>
+                <p>Initializing...</p>
+              </div>
+            </div>
+          </div>
+        ) : (
+          // Main content after initialization
+          <main className="flex h-screen w-screen overflow-hidden">
+            <Tabs
+              defaultValue={currentTab}
+              className="!m-0 flex grow flex-col !p-0 overflow-hidden"
+              onValueChange={handleTabChange}
+            >
+              <SiteHeader />
+              <div className="relative grow">
+                <TabsContent value="documents" className="absolute top-0 right-0 bottom-0 left-0 overflow-auto">
+                  <DocumentManager />
+                </TabsContent>
+                <TabsContent value="knowledge-graph" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
+                  <GraphViewer />
+                </TabsContent>
+                <TabsContent value="retrieval" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
+                  <RetrievalTesting />
+                </TabsContent>
+                <TabsContent value="api" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
+                  <ApiSite />
+                </TabsContent>
+              </div>
+            </Tabs>
+            {enableHealthCheck && <StatusIndicator />}
+            <ApiKeyAlert open={apiKeyAlertOpen} onOpenChange={handleApiKeyAlertOpenChange} />
+          </main>
+        )}
       </TabVisibilityProvider>
     </ThemeProvider>
   )

View File

@@ -80,7 +80,12 @@ const AppRouter = () => {
     <ThemeProvider>
       <Router>
         <AppContent />
-        <Toaster position="bottom-center" />
+        <Toaster
+          position="bottom-center"
+          theme="system"
+          closeButton
+          richColors
+        />
       </Router>
     </ThemeProvider>
   )

View File

@@ -46,6 +46,8 @@ export type LightragStatus = {
   api_version?: string
   auth_mode?: 'enabled' | 'disabled'
   pipeline_busy: boolean
+  webui_title?: string
+  webui_description?: string
 }

 export type LightragDocumentsScanProgress = {
@@ -140,6 +142,8 @@ export type AuthStatusResponse = {
   message?: string
   core_version?: string
   api_version?: string
+  webui_title?: string
+  webui_description?: string
 }

 export type PipelineStatusResponse = {
@@ -163,6 +167,8 @@ export type LoginResponse = {
   message?: string // Optional message
   core_version?: string
   api_version?: string
+  webui_title?: string
+  webui_description?: string
 }

 export const InvalidApiKeyError = 'Invalid API Key'
@@ -221,9 +227,9 @@ axiosInstance.interceptors.response.use(
 export const queryGraphs = async (
   label: string,
   maxDepth: number,
-  minDegree: number
+  maxNodes: number
 ): Promise<LightragGraphType> => {
-  const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
+  const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&max_nodes=${maxNodes}`)
   return response.data
 }
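The same endpoint change seen from outside the WebUI; a hedged sketch using `requests` against the API server (default address assumed, adjust to your deployment):

```python
import requests

BASE_URL = "http://localhost:9621"  # assumed default LightRAG server address

def query_graphs(label: str = "*", max_depth: int = 3, max_nodes: int = 1000) -> dict:
    # min_degree is gone; max_nodes now caps the size of the returned subgraph.
    resp = requests.get(
        f"{BASE_URL}/graphs",
        params={"label": label, "max_depth": max_depth, "max_nodes": max_nodes},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()

graph = query_graphs()
print(len(graph.get("nodes", [])), "nodes; truncated:", graph.get("is_truncated"))
```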
@@ -382,6 +388,14 @@ export const clearDocuments = async (): Promise<DocActionResponse> => {
   return response.data
 }

+export const clearCache = async (modes?: string[]): Promise<{
+  status: 'success' | 'fail'
+  message: string
+}> => {
+  const response = await axiosInstance.post('/documents/clear_cache', { modes })
+  return response.data
+}
+
 export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
   try {
     // Add a timeout to the request to prevent hanging
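And the new cache endpoint from Python, for completeness (again a sketch; the mode names follow the `clear_cache` API documented in the README):

```python
import requests

# Clears only the named cache modes; omit the body to clear everything.
resp = requests.post(
    "http://localhost:9621/documents/clear_cache",  # assumed server address
    json={"modes": ["default"]},
    timeout=30,
)
print(resp.json())  # {"status": "success" | "fail", "message": "..."}
```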

View File

@@ -1,4 +1,4 @@
-import { useState, useCallback } from 'react'
+import { useState, useCallback, useEffect } from 'react'
 import Button from '@/components/ui/Button'
 import {
   Dialog,
@@ -6,32 +6,88 @@ import {
   DialogDescription,
   DialogHeader,
   DialogTitle,
-  DialogTrigger
+  DialogTrigger,
+  DialogFooter
 } from '@/components/ui/Dialog'
+import Input from '@/components/ui/Input'
+import Checkbox from '@/components/ui/Checkbox'
 import { toast } from 'sonner'
 import { errorMessage } from '@/lib/utils'
-import { clearDocuments } from '@/api/lightrag'
-import { EraserIcon } from 'lucide-react'
+import { clearDocuments, clearCache } from '@/api/lightrag'
+import { EraserIcon, AlertTriangleIcon } from 'lucide-react'
 import { useTranslation } from 'react-i18next'

-export default function ClearDocumentsDialog() {
+// Simple Label component
+const Label = ({
+  htmlFor,
+  className,
+  children,
+  ...props
+}: React.LabelHTMLAttributes<HTMLLabelElement>) => (
+  <label
+    htmlFor={htmlFor}
+    className={className}
+    {...props}
+  >
+    {children}
+  </label>
+)
+
+interface ClearDocumentsDialogProps {
+  onDocumentsCleared?: () => Promise<void>
+}
+
+export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocumentsDialogProps) {
   const { t } = useTranslation()
   const [open, setOpen] = useState(false)
+  const [confirmText, setConfirmText] = useState('')
+  const [clearCacheOption, setClearCacheOption] = useState(false)
+  const isConfirmEnabled = confirmText.toLowerCase() === 'yes'
+
+  // Reset state when the dialog closes
+  useEffect(() => {
+    if (!open) {
+      setConfirmText('')
+      setClearCacheOption(false)
+    }
+  }, [open])

   const handleClear = useCallback(async () => {
+    if (!isConfirmEnabled) return
+
     try {
       const result = await clearDocuments()
-      if (result.status === 'success') {
-        toast.success(t('documentPanel.clearDocuments.success'))
-        setOpen(false)
-      } else {
+
+      if (result.status !== 'success') {
         toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
+        setConfirmText('')
+        return
       }
+
+      toast.success(t('documentPanel.clearDocuments.success'))
+
+      if (clearCacheOption) {
+        try {
+          await clearCache()
+          toast.success(t('documentPanel.clearDocuments.cacheCleared'))
+        } catch (cacheErr) {
+          toast.error(t('documentPanel.clearDocuments.cacheClearFailed', { error: errorMessage(cacheErr) }))
+        }
+      }
+
+      // Refresh document list if provided
+      if (onDocumentsCleared) {
+        onDocumentsCleared().catch(console.error)
+      }
+
+      // Close the dialog only after all operations succeed
+      setOpen(false)
     } catch (err) {
       toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
+      setConfirmText('')
     }
-  }, [setOpen, t])
+  }, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared])

   return (
     <Dialog open={open} onOpenChange={setOpen}>
@@ -42,12 +98,60 @@ export default function ClearDocumentsDialog() {
       </DialogTrigger>
       <DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
         <DialogHeader>
-          <DialogTitle>{t('documentPanel.clearDocuments.title')}</DialogTitle>
-          <DialogDescription>{t('documentPanel.clearDocuments.confirm')}</DialogDescription>
+          <DialogTitle className="flex items-center gap-2 text-red-500 dark:text-red-400 font-bold">
+            <AlertTriangleIcon className="h-5 w-5" />
+            {t('documentPanel.clearDocuments.title')}
+          </DialogTitle>
+          <DialogDescription className="pt-2">
+            {t('documentPanel.clearDocuments.description')}
+          </DialogDescription>
         </DialogHeader>
-        <Button variant="destructive" onClick={handleClear}>
-          {t('documentPanel.clearDocuments.confirmButton')}
-        </Button>
+
+        <div className="text-red-500 dark:text-red-400 font-semibold mb-4">
+          {t('documentPanel.clearDocuments.warning')}
+        </div>
+
+        <div className="mb-4">
+          {t('documentPanel.clearDocuments.confirm')}
+        </div>
+
+        <div className="space-y-4">
+          <div className="space-y-2">
+            <Label htmlFor="confirm-text" className="text-sm font-medium">
+              {t('documentPanel.clearDocuments.confirmPrompt')}
+            </Label>
+            <Input
+              id="confirm-text"
+              value={confirmText}
+              onChange={(e: React.ChangeEvent<HTMLInputElement>) => setConfirmText(e.target.value)}
+              placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')}
+              className="w-full"
+            />
+          </div>
+
+          <div className="flex items-center space-x-2">
+            <Checkbox
+              id="clear-cache"
+              checked={clearCacheOption}
+              onCheckedChange={(checked: boolean | 'indeterminate') => setClearCacheOption(checked === true)}
+            />
+            <Label htmlFor="clear-cache" className="text-sm font-medium cursor-pointer">
+              {t('documentPanel.clearDocuments.clearCache')}
+            </Label>
+          </div>
+        </div>
+
+        <DialogFooter>
+          <Button variant="outline" onClick={() => setOpen(false)}>
+            {t('common.cancel')}
+          </Button>
+          <Button
+            variant="destructive"
+            onClick={handleClear}
+            disabled={!isConfirmEnabled}
+          >
+            {t('documentPanel.clearDocuments.confirmButton')}
+          </Button>
+        </DialogFooter>
       </DialogContent>
     </Dialog>
   )

View File

@@ -17,7 +17,11 @@ import { uploadDocument } from '@/api/lightrag'
 import { UploadIcon } from 'lucide-react'
 import { useTranslation } from 'react-i18next'

-export default function UploadDocumentsDialog() {
+interface UploadDocumentsDialogProps {
+  onDocumentsUploaded?: () => Promise<void>
+}
+
+export default function UploadDocumentsDialog({ onDocumentsUploaded }: UploadDocumentsDialogProps) {
   const { t } = useTranslation()
   const [open, setOpen] = useState(false)
   const [isUploading, setIsUploading] = useState(false)
@@ -55,6 +59,7 @@ export default function UploadDocumentsDialog() {
   const handleDocumentsUpload = useCallback(
     async (filesToUpload: File[]) => {
       setIsUploading(true)
+      let hasSuccessfulUpload = false

       // Only clear errors for files that are being uploaded, keep errors for rejected files
       setFileErrors(prev => {
@@ -101,6 +106,9 @@ export default function UploadDocumentsDialog() {
               ...prev,
               [file.name]: result.message
             }))
+          } else {
+            // Mark that we had at least one successful upload
+            hasSuccessfulUpload = true
           }
         } catch (err) {
           console.error(`Upload failed for ${file.name}:`, err)
@@ -142,6 +150,16 @@ export default function UploadDocumentsDialog() {
         } else {
           toast.success(t('documentPanel.uploadDocuments.batch.success'), { id: toastId })
         }
+
+        // Only update if at least one file was uploaded successfully
+        if (hasSuccessfulUpload) {
+          // Refresh document list
+          if (onDocumentsUploaded) {
+            onDocumentsUploaded().catch(err => {
+              console.error('Error refreshing documents:', err)
+            })
+          }
+        }
       } catch (err) {
         console.error('Unexpected error during upload:', err)
         toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) }), { id: toastId })
@@ -149,7 +167,7 @@ export default function UploadDocumentsDialog() {
         setIsUploading(false)
       }
     },
-    [setIsUploading, setProgresses, setFileErrors, t]
+    [setIsUploading, setProgresses, setFileErrors, t, onDocumentsUploaded]
   )

   return (

View File

@@ -36,6 +36,8 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
   const enableEdgeEvents = useSettingsStore.use.enableEdgeEvents()
   const renderEdgeLabels = useSettingsStore.use.showEdgeLabel()
   const renderLabels = useSettingsStore.use.showNodeLabel()
+  const minEdgeSize = useSettingsStore.use.minEdgeSize()
+  const maxEdgeSize = useSettingsStore.use.maxEdgeSize()
   const selectedNode = useGraphStore.use.selectedNode()
   const focusedNode = useGraphStore.use.focusedNode()
   const selectedEdge = useGraphStore.use.selectedEdge()
@@ -136,6 +138,51 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
     registerEvents(events)
   }, [registerEvents, enableEdgeEvents])

+  /**
+   * When edge size settings change, recalculate edge sizes and refresh the sigma instance
+   * to ensure changes take effect immediately
+   */
+  useEffect(() => {
+    if (sigma && sigmaGraph) {
+      // Get the graph from sigma
+      const graph = sigma.getGraph()
+
+      // Find min and max weight values
+      let minWeight = Number.MAX_SAFE_INTEGER
+      let maxWeight = 0
+      graph.forEachEdge(edge => {
+        // Get original weight (before scaling)
+        const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
+        if (typeof weight === 'number') {
+          minWeight = Math.min(minWeight, weight)
+          maxWeight = Math.max(maxWeight, weight)
+        }
+      })
+
+      // Scale edge sizes based on weight range and current min/max edge size settings
+      const weightRange = maxWeight - minWeight
+      if (weightRange > 0) {
+        const sizeScale = maxEdgeSize - minEdgeSize
+        graph.forEachEdge(edge => {
+          const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
+          if (typeof weight === 'number') {
+            const scaledSize = minEdgeSize + sizeScale * Math.pow((weight - minWeight) / weightRange, 0.5)
+            graph.setEdgeAttribute(edge, 'size', scaledSize)
+          }
+        })
+      } else {
+        // If all weights are the same, use default size
+        graph.forEachEdge(edge => {
+          graph.setEdgeAttribute(edge, 'size', minEdgeSize)
+        })
+      }
+
+      // Refresh the sigma instance to apply changes
+      sigma.refresh()
+    }
+  }, [sigma, sigmaGraph, minEdgeSize, maxEdgeSize])
+
   /**
    * When component mount or hovered node change
    * => Setting the sigma reducers
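For reviewers, the scaling rule above in isolation: edge weights are square-root-normalized into `[minEdgeSize, maxEdgeSize]`, which keeps low-weight edges visible while compressing large ranges. A quick standalone check of the same formula:

```python
def scale_edge_size(weight: float, min_w: float, max_w: float,
                    min_size: float = 1.0, max_size: float = 5.0) -> float:
    # Mirrors the effect above: sqrt-normalize weight into [min_size, max_size].
    if max_w <= min_w:
        return min_size  # all weights equal -> default size
    return min_size + (max_size - min_size) * ((weight - min_w) / (max_w - min_w)) ** 0.5

for w in (1, 25, 50, 100):
    print(w, round(scale_edge_size(w, 1, 100), 2))  # 1.0, 2.97, 3.81, 5.0
```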

View File

@@ -1,4 +1,4 @@
-import { useCallback } from 'react'
+import { useCallback, useEffect } from 'react'
 import { AsyncSelect } from '@/components/ui/AsyncSelect'
 import { useSettingsStore } from '@/stores/settings'
 import { useGraphStore } from '@/stores/graph'
@@ -56,6 +56,23 @@ const GraphLabels = () => {
     [getSearchEngine]
   )

+  // Validate if current queryLabel exists in allDatabaseLabels
+  useEffect(() => {
+    // Only update label when all conditions are met:
+    // 1. allDatabaseLabels is loaded (length > 1, as it has at least '*' by default)
+    // 2. Current label is not the default '*'
+    // 3. Current label doesn't exist in allDatabaseLabels
+    if (
+      allDatabaseLabels.length > 1 &&
+      label &&
+      label !== '*' &&
+      !allDatabaseLabels.includes(label)
+    ) {
+      console.log(`Label "${label}" not found in available labels, resetting to default`);
+      useSettingsStore.getState().setQueryLabel('*');
+    }
+  }, [allDatabaseLabels, label]);
+
   const handleRefresh = useCallback(() => {
     // Reset fetch status flags
     useGraphStore.getState().setLabelsFetchAttempted(false)

View File

@@ -0,0 +1,41 @@
+import React from 'react'
+import { useTranslation } from 'react-i18next'
+import { useGraphStore } from '@/stores/graph'
+import { Card } from '@/components/ui/Card'
+import { ScrollArea } from '@/components/ui/ScrollArea'
+
+interface LegendProps {
+  className?: string
+}
+
+const Legend: React.FC<LegendProps> = ({ className }) => {
+  const { t } = useTranslation()
+  const typeColorMap = useGraphStore.use.typeColorMap()
+
+  if (!typeColorMap || typeColorMap.size === 0) {
+    return null
+  }
+
+  return (
+    <Card className={`p-2 max-w-xs ${className}`}>
+      <h3 className="text-sm font-medium mb-2">{t('graphPanel.legend')}</h3>
+      <ScrollArea className="max-h-40">
+        <div className="flex flex-col gap-1">
+          {Array.from(typeColorMap.entries()).map(([type, color]) => (
+            <div key={type} className="flex items-center gap-2">
+              <div
+                className="w-4 h-4 rounded-full"
+                style={{ backgroundColor: color }}
+              />
+              <span className="text-xs truncate" title={type}>
+                {type}
+              </span>
+            </div>
+          ))}
+        </div>
+      </ScrollArea>
+    </Card>
+  )
+}
+
+export default Legend

View File

@@ -0,0 +1,32 @@
+import { useCallback } from 'react'
+import { BookOpenIcon } from 'lucide-react'
+import Button from '@/components/ui/Button'
+import { controlButtonVariant } from '@/lib/constants'
+import { useSettingsStore } from '@/stores/settings'
+import { useTranslation } from 'react-i18next'
+
+/**
+ * Component that toggles legend visibility.
+ */
+const LegendButton = () => {
+  const { t } = useTranslation()
+  const showLegend = useSettingsStore.use.showLegend()
+  const setShowLegend = useSettingsStore.use.setShowLegend()
+
+  const toggleLegend = useCallback(() => {
+    setShowLegend(!showLegend)
+  }, [showLegend, setShowLegend])
+
+  return (
+    <Button
+      variant={controlButtonVariant}
+      onClick={toggleLegend}
+      tooltip={t('graphPanel.sideBar.legendControl.toggleLegend')}
+      size="icon"
+    >
+      <BookOpenIcon />
+    </Button>
+  )
+}
+
+export default LegendButton

View File

@@ -8,7 +8,7 @@ import Input from '@/components/ui/Input'
 import { controlButtonVariant } from '@/lib/constants'
 import { useSettingsStore } from '@/stores/settings'
-import { SettingsIcon } from 'lucide-react'
+import { SettingsIcon, Undo2 } from 'lucide-react'
 import { useTranslation } from 'react-i18next';

 /**
@@ -44,14 +44,17 @@ const LabeledNumberInput = ({
   onEditFinished,
   label,
   min,
-  max
+  max,
+  defaultValue
 }: {
   value: number
   onEditFinished: (value: number) => void
   label: string
   min: number
   max?: number
+  defaultValue?: number
 }) => {
+  const { t } = useTranslation();
   const [currentValue, setCurrentValue] = useState<number | null>(value)

   const onValueChange = useCallback(
@@ -81,6 +84,13 @@ const LabeledNumberInput = ({
     }
   }, [value, currentValue, onEditFinished])

+  const handleReset = useCallback(() => {
+    if (defaultValue !== undefined && value !== defaultValue) {
+      setCurrentValue(defaultValue)
+      onEditFinished(defaultValue)
+    }
+  }, [defaultValue, value, onEditFinished])
+
   return (
     <div className="flex flex-col gap-2">
       <label
@@ -89,20 +99,34 @@ const LabeledNumberInput = ({
       >
         {label}
       </label>
-      <Input
-        type="number"
-        value={currentValue === null ? '' : currentValue}
-        onChange={onValueChange}
-        className="h-6 w-full min-w-0 pr-1"
-        min={min}
-        max={max}
-        onBlur={onBlur}
-        onKeyDown={(e) => {
-          if (e.key === 'Enter') {
-            onBlur()
-          }
-        }}
-      />
+      <div className="flex items-center gap-1">
+        <Input
+          type="number"
+          value={currentValue === null ? '' : currentValue}
+          onChange={onValueChange}
+          className="h-6 w-full min-w-0 pr-1"
+          min={min}
+          max={max}
+          onBlur={onBlur}
+          onKeyDown={(e) => {
+            if (e.key === 'Enter') {
+              onBlur()
+            }
+          }}
+        />
+        {defaultValue !== undefined && (
+          <Button
+            variant="ghost"
+            size="icon"
+            className="h-6 w-6 flex-shrink-0 hover:bg-muted text-muted-foreground hover:text-foreground"
+            onClick={handleReset}
+            type="button"
+            title={t('graphPanel.sideBar.settings.resetToDefault')}
+          >
+            <Undo2 className="h-3.5 w-3.5" />
+          </Button>
+        )}
+      </div>
     </div>
   )
 }
@@ -120,8 +144,10 @@ export default function Settings() {
   const enableNodeDrag = useSettingsStore.use.enableNodeDrag()
   const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
   const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
+  const minEdgeSize = useSettingsStore.use.minEdgeSize()
+  const maxEdgeSize = useSettingsStore.use.maxEdgeSize()
   const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
-  const graphMinDegree = useSettingsStore.use.graphMinDegree()
+  const graphMaxNodes = useSettingsStore.use.graphMaxNodes()
   const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
   const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
@@ -180,15 +206,14 @@ export default function Settings() {
     }, 300)
   }, [])

-  const setGraphMinDegree = useCallback((degree: number) => {
-    if (degree < 0) return
-    useSettingsStore.setState({ graphMinDegree: degree })
+  const setGraphMaxNodes = useCallback((nodes: number) => {
+    if (nodes < 1 || nodes > 1000) return
+    useSettingsStore.setState({ graphMaxNodes: nodes })
     const currentLabel = useSettingsStore.getState().queryLabel
     useSettingsStore.getState().setQueryLabel('')
     setTimeout(() => {
       useSettingsStore.getState().setQueryLabel(currentLabel)
     }, 300)
   }, [])

   const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
@@ -269,24 +294,75 @@ export default function Settings() {
           label={t('graphPanel.sideBar.settings.edgeEvents')}
         />

+        <div className="flex flex-col gap-2">
+          <label className="text-sm leading-none font-medium peer-disabled:cursor-not-allowed peer-disabled:opacity-70">
+            {t('graphPanel.sideBar.settings.edgeSizeRange')}
+          </label>
+          <div className="flex items-center gap-2">
+            <Input
+              type="number"
+              value={minEdgeSize}
+              onChange={(e) => {
+                const newValue = Number(e.target.value);
+                if (!isNaN(newValue) && newValue >= 1 && newValue <= maxEdgeSize) {
+                  useSettingsStore.setState({ minEdgeSize: newValue });
+                }
+              }}
+              className="h-6 w-16 min-w-0 pr-1"
+              min={1}
+              max={Math.min(maxEdgeSize, 10)}
+            />
+            <span>-</span>
+            <div className="flex items-center gap-1">
+              <Input
+                type="number"
+                value={maxEdgeSize}
+                onChange={(e) => {
+                  const newValue = Number(e.target.value);
+                  if (!isNaN(newValue) && newValue >= minEdgeSize && newValue >= 1 && newValue <= 10) {
+                    useSettingsStore.setState({ maxEdgeSize: newValue });
+                  }
+                }}
+                className="h-6 w-16 min-w-0 pr-1"
+                min={minEdgeSize}
+                max={10}
+              />
+              <Button
+                variant="ghost"
+                size="icon"
+                className="h-6 w-6 flex-shrink-0 hover:bg-muted text-muted-foreground hover:text-foreground"
+                onClick={() => useSettingsStore.setState({ minEdgeSize: 1, maxEdgeSize: 5 })}
+                type="button"
+                title={t('graphPanel.sideBar.settings.resetToDefault')}
+              >
+                <Undo2 className="h-3.5 w-3.5" />
+              </Button>
+            </div>
+          </div>
+        </div>
+
         <Separator />
         <LabeledNumberInput
           label={t('graphPanel.sideBar.settings.maxQueryDepth')}
           min={1}
           value={graphQueryMaxDepth}
+          defaultValue={3}
           onEditFinished={setGraphQueryMaxDepth}
         />
         <LabeledNumberInput
-          label={t('graphPanel.sideBar.settings.minDegree')}
-          min={0}
-          value={graphMinDegree}
-          onEditFinished={setGraphMinDegree}
+          label={t('graphPanel.sideBar.settings.maxNodes')}
+          min={1}
+          max={1000}
+          value={graphMaxNodes}
+          defaultValue={1000}
+          onEditFinished={setGraphMaxNodes}
         />
         <LabeledNumberInput
           label={t('graphPanel.sideBar.settings.maxLayoutIterations')}
           min={1}
           max={30}
           value={graphLayoutMaxIterations}
+          defaultValue={15}
           onEditFinished={setGraphLayoutMaxIterations}
         />
         <Separator />

View File

@@ -8,12 +8,12 @@ import { useTranslation } from 'react-i18next'
 const SettingsDisplay = () => {
   const { t } = useTranslation()
   const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
-  const graphMinDegree = useSettingsStore.use.graphMinDegree()
+  const graphMaxNodes = useSettingsStore.use.graphMaxNodes()

   return (
     <div className="absolute bottom-4 left-[calc(1rem+2.5rem)] flex items-center gap-2 text-xs text-gray-400">
       <div>{t('graphPanel.sideBar.settings.depth')}: {graphQueryMaxDepth}</div>
-      <div>{t('graphPanel.sideBar.settings.degree')}: {graphMinDegree}</div>
+      <div>{t('graphPanel.sideBar.settings.max')}: {graphMaxNodes}</div>
     </div>
   )
 }

View File

@@ -4,14 +4,14 @@ import { useTranslation } from 'react-i18next'
 const StatusCard = ({ status }: { status: LightragStatus | null }) => {
   const { t } = useTranslation()
   if (!status) {
-    return <div className="text-muted-foreground text-sm">{t('graphPanel.statusCard.unavailable')}</div>
+    return <div className="text-foreground text-xs">{t('graphPanel.statusCard.unavailable')}</div>
   }
   return (
-    <div className="min-w-[300px] space-y-3 text-sm">
+    <div className="min-w-[300px] space-y-2 text-xs">
       <div className="space-y-1">
         <h4 className="font-medium">{t('graphPanel.statusCard.storageInfo')}</h4>
-        <div className="text-muted-foreground grid grid-cols-2 gap-1">
+        <div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
           <span>{t('graphPanel.statusCard.workingDirectory')}:</span>
           <span className="truncate">{status.working_directory}</span>
           <span>{t('graphPanel.statusCard.inputDirectory')}:</span>
@@ -21,7 +21,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {
       <div className="space-y-1">
         <h4 className="font-medium">{t('graphPanel.statusCard.llmConfig')}</h4>
-        <div className="text-muted-foreground grid grid-cols-2 gap-1">
+        <div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
           <span>{t('graphPanel.statusCard.llmBinding')}:</span>
           <span>{status.configuration.llm_binding}</span>
           <span>{t('graphPanel.statusCard.llmBindingHost')}:</span>
@@ -35,7 +35,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {
       <div className="space-y-1">
         <h4 className="font-medium">{t('graphPanel.statusCard.embeddingConfig')}</h4>
-        <div className="text-muted-foreground grid grid-cols-2 gap-1">
+        <div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
           <span>{t('graphPanel.statusCard.embeddingBinding')}:</span>
           <span>{status.configuration.embedding_binding}</span>
           <span>{t('graphPanel.statusCard.embeddingBindingHost')}:</span>
@@ -47,7 +47,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {
       <div className="space-y-1">
         <h4 className="font-medium">{t('graphPanel.statusCard.storageConfig')}</h4>
-        <div className="text-muted-foreground grid grid-cols-2 gap-1">
+        <div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
           <span>{t('graphPanel.statusCard.kvStorage')}:</span>
           <span>{status.configuration.kv_storage}</span>
           <span>{t('graphPanel.statusCard.docStatusStorage')}:</span>

View File

@@ -0,0 +1,32 @@
+import { LightragStatus } from '@/api/lightrag'
+import { useTranslation } from 'react-i18next'
+import {
+  Dialog,
+  DialogContent,
+  DialogHeader,
+  DialogTitle,
+} from '@/components/ui/Dialog'
+import StatusCard from './StatusCard'
+
+interface StatusDialogProps {
+  open: boolean
+  onOpenChange: (open: boolean) => void
+  status: LightragStatus | null
+}
+
+const StatusDialog = ({ open, onOpenChange, status }: StatusDialogProps) => {
+  const { t } = useTranslation()
+
+  return (
+    <Dialog open={open} onOpenChange={onOpenChange}>
+      <DialogContent className="sm:max-w-[500px]">
+        <DialogHeader>
+          <DialogTitle>{t('graphPanel.statusDialog.title')}</DialogTitle>
+        </DialogHeader>
+        <StatusCard status={status} />
+      </DialogContent>
+    </Dialog>
+  )
+}
+
+export default StatusDialog

View File

@@ -1,8 +1,7 @@
 import { cn } from '@/lib/utils'
 import { useBackendState } from '@/stores/state'
 import { useEffect, useState } from 'react'
-import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/Popover'
-import StatusCard from '@/components/status/StatusCard'
+import StatusDialog from './StatusDialog'
 import { useTranslation } from 'react-i18next'

 const StatusIndicator = () => {
@@ -11,6 +10,7 @@ const StatusIndicator = () => {
   const lastCheckTime = useBackendState.use.lastCheckTime()
   const status = useBackendState.use.status()
   const [animate, setAnimate] = useState(false)
+  const [dialogOpen, setDialogOpen] = useState(false)

   // listen to health change
   useEffect(() => {
@@ -21,28 +21,30 @@ const StatusIndicator = () => {
   return (
     <div className="fixed right-4 bottom-4 flex items-center gap-2 opacity-80 select-none">
-      <Popover>
-        <PopoverTrigger asChild>
-          <div className="flex cursor-help items-center gap-2">
-            <div
-              className={cn(
-                'h-3 w-3 rounded-full transition-all duration-300',
-                'shadow-[0_0_8px_rgba(0,0,0,0.2)]',
-                health ? 'bg-green-500' : 'bg-red-500',
-                animate && 'scale-125',
-                animate && health && 'shadow-[0_0_12px_rgba(34,197,94,0.4)]',
-                animate && !health && 'shadow-[0_0_12px_rgba(239,68,68,0.4)]'
-              )}
-            />
-            <span className="text-muted-foreground text-xs">
-              {health ? t('graphPanel.statusIndicator.connected') : t('graphPanel.statusIndicator.disconnected')}
-            </span>
-          </div>
-        </PopoverTrigger>
-        <PopoverContent className="w-auto" side="top" align="end">
-          <StatusCard status={status} />
-        </PopoverContent>
-      </Popover>
+      <div
+        className="flex cursor-pointer items-center gap-2"
+        onClick={() => setDialogOpen(true)}
+      >
+        <div
+          className={cn(
+            'h-3 w-3 rounded-full transition-all duration-300',
+            'shadow-[0_0_8px_rgba(0,0,0,0.2)]',
+            health ? 'bg-green-500' : 'bg-red-500',
+            animate && 'scale-125',
+            animate && health && 'shadow-[0_0_12px_rgba(34,197,94,0.4)]',
+            animate && !health && 'shadow-[0_0_12px_rgba(239,68,68,0.4)]'
+          )}
+        />
+        <span className="text-muted-foreground text-xs">
+          {health ? t('graphPanel.statusIndicator.connected') : t('graphPanel.statusIndicator.disconnected')}
+        </span>
+      </div>
+
+      <StatusDialog
+        open={dialogOpen}
+        onOpenChange={setDialogOpen}
+        status={status}
+      />
     </div>
   )
 }

View File

@@ -11,7 +11,7 @@ const Checkbox = React.forwardRef<
   <CheckboxPrimitive.Root
     ref={ref}
     className={cn(
-      'peer border-primary ring-offset-background focus-visible:ring-ring data-[state=checked]:bg-primary data-[state=checked]:text-primary-foreground h-4 w-4 shrink-0 rounded-sm border focus-visible:ring-2 focus-visible:ring-offset-2 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
+      'peer border-primary ring-offset-background focus-visible:ring-ring data-[state=checked]:bg-muted data-[state=checked]:text-muted-foreground h-4 w-4 shrink-0 rounded-sm border focus-visible:ring-2 focus-visible:ring-offset-2 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
       className
     )}
     {...props}

View File

@@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<'input'>>(
     <input
       type={type}
       className={cn(
-        'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-100 [&::-webkit-outer-spin-button]:opacity-100',
+        'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-50 [&::-webkit-outer-spin-button]:opacity-50',
         className
       )}
       ref={ref}

View File

@@ -1,4 +1,4 @@
-import { useState, useEffect, useCallback, useRef } from 'react'
+import { useState, useEffect, useCallback, useMemo, useRef } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useSettingsStore } from '@/stores/settings'
 import Button from '@/components/ui/Button'
@@ -16,15 +16,17 @@ import EmptyCard from '@/components/ui/EmptyCard'
 import UploadDocumentsDialog from '@/components/documents/UploadDocumentsDialog'
 import ClearDocumentsDialog from '@/components/documents/ClearDocumentsDialog'
-import { getDocuments, scanNewDocuments, DocsStatusesResponse } from '@/api/lightrag'
+import { getDocuments, scanNewDocuments, DocsStatusesResponse, DocStatus, DocStatusResponse } from '@/api/lightrag'
 import { errorMessage } from '@/lib/utils'
 import { toast } from 'sonner'
 import { useBackendState } from '@/stores/state'
-import { RefreshCwIcon, ActivityIcon, ArrowUpIcon, ArrowDownIcon } from 'lucide-react'
-import { DocStatusResponse } from '@/api/lightrag'
+import { RefreshCwIcon, ActivityIcon, ArrowUpIcon, ArrowDownIcon, FilterIcon } from 'lucide-react'
 import PipelineStatusDialog from '@/components/documents/PipelineStatusDialog'

+type StatusFilter = DocStatus | 'all';
+
 const getDisplayFileName = (doc: DocStatusResponse, maxLength: number = 20): string => {
   // Check if file_path exists and is a non-empty string
   if (!doc.file_path || typeof doc.file_path !== 'string' || doc.file_path.trim() === '') {
@@ -148,6 +150,10 @@ export default function DocumentManager() {
   const [sortField, setSortField] = useState<SortField>('updated_at')
   const [sortDirection, setSortDirection] = useState<SortDirection>('desc')

+  // State for document status filter
+  const [statusFilter, setStatusFilter] = useState<StatusFilter>('all');
+
   // Handle sort column click
   const handleSort = (field: SortField) => {
     if (sortField === field) {
@@ -161,7 +167,7 @@ export default function DocumentManager() {
   }

   // Sort documents based on current sort field and direction
-  const sortDocuments = (documents: DocStatusResponse[]) => {
+  const sortDocuments = useCallback((documents: DocStatusResponse[]) => {
     return [...documents].sort((a, b) => {
       let valueA, valueB;
@@ -188,7 +194,50 @@ export default function DocumentManager() {
         return sortMultiplier * (valueA > valueB ? 1 : valueA < valueB ? -1 : 0);
       }
     });
-  }
+  }, [sortField, sortDirection, showFileName]);
+
+  const filteredAndSortedDocs = useMemo(() => {
+    if (!docs) return null;
+
+    let filteredDocs = { ...docs };
+    if (statusFilter !== 'all') {
+      filteredDocs = {
+        ...docs,
+        statuses: {
+          pending: [],
+          processing: [],
+          processed: [],
+          failed: [],
+          [statusFilter]: docs.statuses[statusFilter] || []
+        }
+      };
+    }
+
+    if (!sortField || !sortDirection) return filteredDocs;
+
+    const sortedStatuses = Object.entries(filteredDocs.statuses).reduce((acc, [status, documents]) => {
+      const sortedDocuments = sortDocuments(documents);
+      acc[status as DocStatus] = sortedDocuments;
+      return acc;
+    }, {} as DocsStatusesResponse['statuses']);
+
+    return { ...filteredDocs, statuses: sortedStatuses };
+  }, [docs, sortField, sortDirection, statusFilter, sortDocuments]);
+
+  // Calculate document counts for each status
+  const documentCounts = useMemo(() => {
+    if (!docs) return { all: 0 } as Record<string, number>;
+
+    const counts: Record<string, number> = { all: 0 };
+    Object.entries(docs.statuses).forEach(([status, documents]) => {
+      counts[status as DocStatus] = documents.length;
+      counts.all += documents.length;
+    });
+
+    return counts;
+  }, [docs]);

   // Store previous status counts
   const prevStatusCounts = useRef({
@@ -386,8 +435,8 @@ export default function DocumentManager() {
           </Button>
         </div>
         <div className="flex-1" />
-        <ClearDocumentsDialog />
-        <UploadDocumentsDialog />
+        <ClearDocumentsDialog onDocumentsCleared={fetchDocuments} />
+        <UploadDocumentsDialog onDocumentsUploaded={fetchDocuments} />
         <PipelineStatusDialog
           open={showPipelineStatus}
           onOpenChange={setShowPipelineStatus}
@@ -398,6 +447,65 @@ export default function DocumentManager() {
         <CardHeader className="flex-none py-2 px-4">
           <div className="flex justify-between items-center">
             <CardTitle>{t('documentPanel.documentManager.uploadedTitle')}</CardTitle>
+            <div className="flex items-center gap-2">
+              <FilterIcon className="h-4 w-4" />
+              <div className="flex gap-1">
+                <Button
+                  size="sm"
+                  variant={statusFilter === 'all' ? 'secondary' : 'outline'}
+                  onClick={() => setStatusFilter('all')}
+                  className={cn(
+                    statusFilter === 'all' && 'bg-gray-100 dark:bg-gray-900 font-medium border border-gray-400 dark:border-gray-500 shadow-sm'
+                  )}
+                >
+                  {t('documentPanel.documentManager.status.all')} ({documentCounts.all})
+                </Button>
+                <Button
+                  size="sm"
+                  variant={statusFilter === 'processed' ? 'secondary' : 'outline'}
+                  onClick={() => setStatusFilter('processed')}
+                  className={cn(
+                    documentCounts.processed > 0 ? 'text-green-600' : 'text-gray-500',
+                    statusFilter === 'processed' && 'bg-green-100 dark:bg-green-900/30 font-medium border border-green-400 dark:border-green-600 shadow-sm'
+                  )}
+                >
+                  {t('documentPanel.documentManager.status.completed')} ({documentCounts.processed || 0})
+                </Button>
+                <Button
+                  size="sm"
+                  variant={statusFilter === 'processing' ? 'secondary' : 'outline'}
+                  onClick={() => setStatusFilter('processing')}
+                  className={cn(
+                    documentCounts.processing > 0 ? 'text-blue-600' : 'text-gray-500',
+                    statusFilter === 'processing' && 'bg-blue-100 dark:bg-blue-900/30 font-medium border border-blue-400 dark:border-blue-600 shadow-sm'
+                  )}
+                >
+                  {t('documentPanel.documentManager.status.processing')} ({documentCounts.processing || 0})
+                </Button>
+                <Button
+                  size="sm"
+                  variant={statusFilter === 'pending' ? 'secondary' : 'outline'}
+                  onClick={() => setStatusFilter('pending')}
+                  className={cn(
+                    documentCounts.pending > 0 ? 'text-yellow-600' : 'text-gray-500',
+                    statusFilter === 'pending' && 'bg-yellow-100 dark:bg-yellow-900/30 font-medium border border-yellow-400 dark:border-yellow-600 shadow-sm'
+                  )}
+                >
+                  {t('documentPanel.documentManager.status.pending')} ({documentCounts.pending || 0})
+                </Button>
+                <Button
+                  size="sm"
+                  variant={statusFilter === 'failed' ? 'secondary' : 'outline'}
+                  onClick={() => setStatusFilter('failed')}
+                  className={cn(
+                    documentCounts.failed > 0 ? 'text-red-600' : 'text-gray-500',
+                    statusFilter === 'failed' && 'bg-red-100 dark:bg-red-900/30 font-medium border border-red-400 dark:border-red-600 shadow-sm'
+                  )}
+                >
+                  {t('documentPanel.documentManager.status.failed')} ({documentCounts.failed || 0})
+                </Button>
+              </div>
+            </div>
             <div className="flex items-center gap-2">
               <span className="text-sm text-gray-500">{t('documentPanel.documentManager.fileNameLabel')}</span>
               <Button
@@ -477,11 +585,8 @@ export default function DocumentManager() {
                 </TableRow>
               </TableHeader>
               <TableBody className="text-sm overflow-auto">
-                {Object.entries(docs.statuses).flatMap(([status, documents]) => {
-                  // Apply sorting to documents
-                  const sortedDocuments = sortDocuments(documents);
-
-                  return sortedDocuments.map(doc => (
+                {filteredAndSortedDocs?.statuses && Object.entries(filteredAndSortedDocs.statuses).flatMap(([status, documents]) =>
+                  documents.map((doc) => (
                     <TableRow key={doc.id}>
                       <TableCell className="truncate font-mono overflow-visible max-w-[250px]">
                         {showFileName ? (
@@ -541,8 +646,8 @@ export default function DocumentManager() {
                         {new Date(doc.updated_at).toLocaleString()}
                       </TableCell>
                     </TableRow>
-                  ));
-                })}
+                  )))
+                }
               </TableBody>
             </Table>
           </div>

View File

@@ -18,6 +18,8 @@ import GraphSearch from '@/components/graph/GraphSearch'
 import GraphLabels from '@/components/graph/GraphLabels'
 import PropertiesView from '@/components/graph/PropertiesView'
 import SettingsDisplay from '@/components/graph/SettingsDisplay'
+import Legend from '@/components/graph/Legend'
+import LegendButton from '@/components/graph/LegendButton'
 import { useSettingsStore } from '@/stores/settings'
 import { useGraphStore } from '@/stores/graph'
@@ -116,6 +118,7 @@ const GraphViewer = () => {
   const showPropertyPanel = useSettingsStore.use.showPropertyPanel()
   const showNodeSearchBar = useSettingsStore.use.showNodeSearchBar()
   const enableNodeDrag = useSettingsStore.use.enableNodeDrag()
+  const showLegend = useSettingsStore.use.showLegend()

   // Initialize sigma settings once on component mount
   // All dynamic settings will be updated in GraphControl using useSetSettings
@@ -195,6 +198,7 @@ const GraphViewer = () => {
           <LayoutsControl />
           <ZoomControl />
           <FullScreenControl />
+          <LegendButton />
           <Settings />
           {/* <ThemeToggle /> */}
         </div>
@@ -205,6 +209,12 @@ const GraphViewer = () => {
         </div>
       )}

+      {showLegend && (
+        <div className="absolute bottom-10 right-2">
+          <Legend className="bg-background/60 backdrop-blur-lg" />
+        </div>
+      )}
+
      {/* <div className="absolute bottom-2 right-2 flex flex-col rounded-xl border-2">
        <MiniMap width="100px" height="100px" />
      </div> */}

View File

@@ -51,7 +51,7 @@ const LoginPage = () => {
         if (!status.auth_configured && status.access_token) {
           // If auth is not configured, use the guest token and redirect
-          login(status.access_token, true, status.core_version, status.api_version)
+          login(status.access_token, true, status.core_version, status.api_version, status.webui_title || null, status.webui_description || null)
           if (status.message) {
             toast.info(status.message)
           }
@@ -96,7 +96,7 @@ const LoginPage = () => {
       // Check authentication mode
       const isGuestMode = response.auth_mode === 'disabled'
-      login(response.access_token, isGuestMode, response.core_version, response.api_version)
+      login(response.access_token, isGuestMode, response.core_version, response.api_version, response.webui_title || null, response.webui_description || null)

       // Set session flag for version check
       if (response.core_version || response.api_version) {

View File

@@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
 import { useTranslation } from 'react-i18next'
 import { navigationService } from '@/services/navigation'
 import { ZapIcon, GithubIcon, LogOutIcon } from 'lucide-react'
+import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip'

 interface NavigationTabProps {
   value: string
export default function SiteHeader() { export default function SiteHeader() {
const { t } = useTranslation() const { t } = useTranslation()
const { isGuestMode, coreVersion, apiVersion, username } = useAuthStore() const { isGuestMode, coreVersion, apiVersion, username, webuiTitle, webuiDescription } = useAuthStore()
const versionDisplay = (coreVersion && apiVersion) const versionDisplay = (coreVersion && apiVersion)
? `${coreVersion}/${apiVersion}` ? `${coreVersion}/${apiVersion}`
@@ -67,17 +68,31 @@ export default function SiteHeader() {
   return (
     <header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
-      <div className="w-[200px] flex items-center">
+      <div className="min-w-[200px] w-auto flex items-center">
         <a href={webuiPrefix} className="flex items-center gap-2">
           <ZapIcon className="size-4 text-emerald-400" aria-hidden="true" />
           {/* <img src='/logo.png' className="size-4" /> */}
           <span className="font-bold md:inline-block">{SiteInfo.name}</span>
-          {versionDisplay && (
-            <span className="ml-2 text-xs text-gray-500 dark:text-gray-400">
-              v{versionDisplay}
-            </span>
-          )}
         </a>
+        {webuiTitle && (
+          <div className="flex items-center">
+            <span className="mx-1 text-xs text-gray-500 dark:text-gray-400">|</span>
+            <TooltipProvider>
+              <Tooltip>
+                <TooltipTrigger asChild>
+                  <span className="font-medium text-sm cursor-default">
+                    {webuiTitle}
+                  </span>
+                </TooltipTrigger>
+                {webuiDescription && (
+                  <TooltipContent side="bottom">
+                    {webuiDescription}
+                  </TooltipContent>
+                )}
+              </Tooltip>
+            </TooltipProvider>
+          </div>
+        )}
       </div>

       <div className="flex h-10 flex-1 items-center justify-center">
@@ -91,6 +106,11 @@ export default function SiteHeader() {
<nav className="w-[200px] flex items-center justify-end"> <nav className="w-[200px] flex items-center justify-end">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
{versionDisplay && (
<span className="text-xs text-gray-500 dark:text-gray-400 mr-1">
v{versionDisplay}
</span>
)}
<Button variant="ghost" size="icon" side="bottom" tooltip={t('header.projectRepository')}> <Button variant="ghost" size="icon" side="bottom" tooltip={t('header.projectRepository')}>
<a href={SiteInfo.github} target="_blank" rel="noopener noreferrer"> <a href={SiteInfo.github} target="_blank" rel="noopener noreferrer">
<GithubIcon className="size-4" aria-hidden="true" /> <GithubIcon className="size-4" aria-hidden="true" />

View File

@@ -11,6 +11,35 @@ import { useSettingsStore } from '@/stores/settings'
 import seedrandom from 'seedrandom'

+// Helper function to generate a color based on type
+const getNodeColorByType = (nodeType: string | undefined): string => {
+  const defaultColor = '#CCCCCC'; // Default color for nodes without a type or undefined type
+  if (!nodeType) {
+    return defaultColor;
+  }
+
+  const typeColorMap = useGraphStore.getState().typeColorMap;
+  if (!typeColorMap.has(nodeType)) {
+    // Generate a color based on the type string itself for consistency
+    // Seed the global random number generator based on the node type
+    seedrandom(nodeType, { global: true });
+    // Call randomColor without arguments; it will use the globally seeded Math.random()
+    const newColor = randomColor();
+
+    const newMap = new Map(typeColorMap);
+    newMap.set(nodeType, newColor);
+    useGraphStore.setState({ typeColorMap: newMap });
+    return newColor;
+  }
+
+  // Restore the default random seed if necessary, though usually not required for this use case
+  // seedrandom(Date.now().toString(), { global: true });
+  return typeColorMap.get(nodeType) || defaultColor; // Add fallback just in case
+};
+
 const validateGraph = (graph: RawGraph) => {
   // Check if graph exists
   if (!graph) {
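The helper above keys colors off the entity type rather than the node id, so every node of a given type shares one color across reloads. A standalone sketch of the same idea, assuming the `seedrandom` and `randomcolor` npm packages and replacing the store with a plain `Map`:

```typescript
import seedrandom from 'seedrandom'
import randomColor from 'randomcolor'

const typeColors = new Map<string, string>()

// Deterministic: the same type string always seeds the same color.
function colorForType(type: string | undefined): string {
  if (!type) return '#CCCCCC'
  let color = typeColors.get(type)
  if (!color) {
    seedrandom(type, { global: true })  // seed Math.random() from the type name
    color = randomColor()               // randomcolor draws from the seeded Math.random()
    typeColors.set(type, color)
  }
  return color
}

// Both calls return the identical hex string, in this run and the next.
console.log(colorForType('person') === colorForType('person')) // true
```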
@@ -68,9 +97,15 @@ export type NodeType = {
   color: string
   highlighted?: boolean
 }

-export type EdgeType = { label: string }
+export type EdgeType = {
+  label: string
+  originalWeight?: number
+  size?: number
+  color?: string
+  hidden?: boolean
+}

-const fetchGraph = async (label: string, maxDepth: number, minDegree: number) => {
+const fetchGraph = async (label: string, maxDepth: number, maxNodes: number) => {
   let rawData: any = null;

   // Check if we need to fetch all database labels first
@@ -89,8 +124,8 @@ const fetchGraph = async (label: string, maxDepth: number, maxNodes: number) =>
   const queryLabel = label || '*';

   try {
-    console.log(`Fetching graph label: ${queryLabel}, depth: ${maxDepth}, deg: ${minDegree}`);
-    rawData = await queryGraphs(queryLabel, maxDepth, minDegree);
+    console.log(`Fetching graph label: ${queryLabel}, depth: ${maxDepth}, nodes: ${maxNodes}`);
+    rawData = await queryGraphs(queryLabel, maxDepth, maxNodes);
   } catch (e) {
     useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!');
     return null;
@@ -106,9 +141,6 @@ const fetchGraph = async (label: string, maxDepth: number, maxNodes: number) =>
     const node = rawData.nodes[i]
     nodeIdMap[node.id] = i

-    // const seed = node.labels.length > 0 ? node.labels[0] : node.id
-    seedrandom(node.id, { global: true })
-    node.color = randomColor()
     node.x = Math.random()
     node.y = Math.random()
     node.degree = 0
@@ -169,11 +201,14 @@ const fetchGraph = async (label: string, maxDepth: number, maxNodes: number) =>
   }

   // console.debug({ data: JSON.parse(JSON.stringify(rawData)) })
-  return rawGraph
+  return { rawGraph, is_truncated: rawData.is_truncated }
 }

 // Create a new graph instance with the raw graph data
 const createSigmaGraph = (rawGraph: RawGraph | null) => {
+  // Get edge size settings from store
+  const minEdgeSize = useSettingsStore.getState().minEdgeSize
+  const maxEdgeSize = useSettingsStore.getState().maxEdgeSize
+
   // Skip graph creation if no data or empty nodes
   if (!rawGraph || !rawGraph.nodes.length) {
     console.log('No graph data available, skipping sigma graph creation');
@@ -204,8 +239,40 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
   // Add edges from raw graph data
   for (const rawEdge of rawGraph?.edges ?? []) {
+    // Get weight from edge properties or default to 1
+    const weight = rawEdge.properties?.weight !== undefined ? Number(rawEdge.properties.weight) : 1
     rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
-      label: rawEdge.properties?.keywords || undefined
+      label: rawEdge.properties?.keywords || undefined,
+      size: weight, // Set initial size based on weight
+      originalWeight: weight, // Store original weight for recalculation
     })
   }

+  // Calculate edge size based on weight range, similar to node size calculation
+  let minWeight = Number.MAX_SAFE_INTEGER
+  let maxWeight = 0
+
+  // Find min and max weight values
+  graph.forEachEdge(edge => {
+    const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
+    minWeight = Math.min(minWeight, weight)
+    maxWeight = Math.max(maxWeight, weight)
+  })
+
+  // Scale edge sizes based on weight range
+  const weightRange = maxWeight - minWeight
+  if (weightRange > 0) {
+    const sizeScale = maxEdgeSize - minEdgeSize
+    graph.forEachEdge(edge => {
+      const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
+      const scaledSize = minEdgeSize + sizeScale * Math.pow((weight - minWeight) / weightRange, 0.5)
+      graph.setEdgeAttribute(edge, 'size', scaledSize)
+    })
+  } else {
+    // If all weights are the same, use default size
+    graph.forEachEdge(edge => {
+      graph.setEdgeAttribute(edge, 'size', minEdgeSize)
+    })
+  }
 }
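The scaling step above maps raw edge weights onto the configured [minEdgeSize, maxEdgeSize] range with a square-root curve, which spreads out the low end and compresses large outliers. A minimal sketch of just the math, with the weight extremes assumed to be precomputed:

```typescript
// Map a raw edge weight into [minSize, maxSize] with sqrt easing.
// Assumes maxWeight > minWeight; callers fall back to minSize otherwise.
function scaleEdgeSize(
  weight: number,
  minWeight: number,
  maxWeight: number,
  minSize: number,
  maxSize: number
): number {
  const t = (weight - minWeight) / (maxWeight - minWeight) // normalize to [0, 1]
  return minSize + (maxSize - minSize) * Math.sqrt(t)      // sqrt flattens the high end
}

// With weights 1..100 and a size range of 1..5:
console.log(scaleEdgeSize(1, 1, 100, 1, 5))   // 1
console.log(scaleEdgeSize(25, 1, 100, 1, 5))  // ~2.97, already past the midpoint
console.log(scaleEdgeSize(100, 1, 100, 1, 5)) // 5
```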
@@ -218,11 +285,12 @@ const useLightrangeGraph = () => {
   const rawGraph = useGraphStore.use.rawGraph()
   const sigmaGraph = useGraphStore.use.sigmaGraph()
   const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
-  const minDegree = useSettingsStore.use.graphMinDegree()
+  const maxNodes = useSettingsStore.use.graphMaxNodes()
   const isFetching = useGraphStore.use.isFetching()
   const nodeToExpand = useGraphStore.use.nodeToExpand()
   const nodeToPrune = useGraphStore.use.nodeToPrune()

   // Use ref to track if data has been loaded and initial load
   const dataLoadedRef = useRef(false)
   const initialLoadRef = useRef(false)
@@ -292,23 +360,37 @@ const useLightrangeGraph = () => {
       // Use a local copy of the parameters
       const currentQueryLabel = queryLabel
       const currentMaxQueryDepth = maxQueryDepth
-      const currentMinDegree = minDegree
+      const currentMaxNodes = maxNodes

       // Declare a variable to store data promise
-      let dataPromise;
+      let dataPromise: Promise<{ rawGraph: RawGraph | null; is_truncated: boolean | undefined } | null>;

       // 1. If query label is not empty, use fetchGraph
       if (currentQueryLabel) {
-        dataPromise = fetchGraph(currentQueryLabel, currentMaxQueryDepth, currentMinDegree);
+        dataPromise = fetchGraph(currentQueryLabel, currentMaxQueryDepth, currentMaxNodes);
       } else {
         // 2. If query label is empty, set data to null
         console.log('Query label is empty, show empty graph')
-        dataPromise = Promise.resolve(null);
+        dataPromise = Promise.resolve({ rawGraph: null, is_truncated: false });
       }

       // 3. Process data
-      dataPromise.then((data) => {
+      dataPromise.then((result) => {
         const state = useGraphStore.getState()
+        const data = result?.rawGraph;
+
+        // Assign colors based on entity_type *after* fetching
+        if (data && data.nodes) {
+          data.nodes.forEach(node => {
+            // Use entity_type instead of type
+            const nodeEntityType = node.properties?.entity_type as string | undefined;
+            node.color = getNodeColorByType(nodeEntityType);
+          });
+        }
+
+        if (result?.is_truncated) {
+          toast.info(t('graphPanel.dataIsTruncated', 'Graph data is truncated to Max Nodes'));
+        }

         // Reset state
         state.reset()
@@ -336,15 +418,23 @@ const useLightrangeGraph = () => {
           // Still mark graph as empty for other logic
           state.setGraphIsEmpty(true);

-          // Only clear current label if it's not already empty
-          if (currentQueryLabel) {
+          // Check if the empty graph is due to 401 authentication error
+          const errorMessage = useBackendState.getState().message;
+          const isAuthError = errorMessage && errorMessage.includes('Authentication required');
+
+          // Only clear queryLabel if it's not an auth error and current label is not empty
+          if (!isAuthError && currentQueryLabel) {
             useSettingsStore.getState().setQueryLabel('');
           }

-          // Clear last successful query label to ensure labels are fetched next time
-          state.setLastSuccessfulQueryLabel('');
+          // Only clear last successful query label if it's not an auth error
+          if (!isAuthError) {
+            state.setLastSuccessfulQueryLabel('');
+          } else {
+            console.log('Keep queryLabel for post-login reload');
+          }

-          console.log('Graph data is empty, created graph with empty graph node');
+          console.log(`Graph data is empty, created graph with empty graph node. Auth error: ${isAuthError}`);
         } else {
           // Create and set new graph
           const newSigmaGraph = createSigmaGraph(data);
@@ -384,7 +474,7 @@ const useLightrangeGraph = () => {
         state.setLastSuccessfulQueryLabel('') // Clear last successful query label on error
       })
     }
-  }, [queryLabel, maxQueryDepth, minDegree, isFetching, t])
+  }, [queryLabel, maxQueryDepth, maxNodes, isFetching, t])

   // Handle node expansion
   useEffect(() => {
@@ -407,7 +497,7 @@ const useLightrangeGraph = () => {
       }

       // Fetch the extended subgraph with depth 2
-      const extendedGraph = await queryGraphs(label, 2, 0);
+      const extendedGraph = await queryGraphs(label, 2, 1000);

       if (!extendedGraph || !extendedGraph.nodes || !extendedGraph.edges) {
         console.error('Failed to fetch extended graph');

View File

@@ -32,14 +32,24 @@
   "authDisabled": "تم تعطيل المصادقة. استخدام وضع بدون تسجيل دخول.",
   "guestMode": "وضع بدون تسجيل دخول"
 },
+"common": {
+  "cancel": "إلغاء"
+},
 "documentPanel": {
   "clearDocuments": {
     "button": "مسح",
     "tooltip": "مسح المستندات",
     "title": "مسح المستندات",
+    "description": "سيؤدي هذا إلى إزالة جميع المستندات من النظام",
+    "warning": "تحذير: سيؤدي هذا الإجراء إلى حذف جميع المستندات بشكل دائم ولا يمكن التراجع عنه!",
     "confirm": "هل تريد حقًا مسح جميع المستندات؟",
+    "confirmPrompt": "اكتب 'yes' لتأكيد هذا الإجراء",
+    "confirmPlaceholder": "اكتب yes للتأكيد",
+    "clearCache": "مسح كاش نموذج اللغة",
     "confirmButton": "نعم",
     "success": "تم مسح المستندات بنجاح",
+    "cacheCleared": "تم مسح ذاكرة التخزين المؤقت بنجاح",
+    "cacheClearFailed": "فشل مسح ذاكرة التخزين المؤقت:\n{{error}}",
     "failed": "فشل مسح المستندات:\n{{message}}",
     "error": "فشل مسح المستندات:\n{{error}}"
   },
@@ -95,6 +105,7 @@
     "metadata": "البيانات الوصفية"
   },
   "status": {
+    "all": "الكل",
     "completed": "مكتمل",
     "processing": "قيد المعالجة",
     "pending": "معلق",
@@ -127,6 +138,11 @@
     }
   },
   "graphPanel": {
+    "dataIsTruncated": "تم اقتصار بيانات الرسم البياني على الحد الأقصى للعقد",
+    "statusDialog": {
+      "title": "إعدادات خادم LightRAG"
+    },
+    "legend": "المفتاح",
     "sideBar": {
       "settings": {
         "settings": "الإعدادات",
@@ -139,9 +155,12 @@
         "hideUnselectedEdges": "إخفاء الحواف غير المحددة",
         "edgeEvents": "أحداث الحافة",
         "maxQueryDepth": "أقصى عمق للاستعلام",
-        "minDegree": "الدرجة الدنيا",
+        "maxNodes": "الحد الأقصى للعقد",
         "maxLayoutIterations": "أقصى تكرارات التخطيط",
-        "depth": "العمق",
+        "resetToDefault": "إعادة التعيين إلى الافتراضي",
+        "edgeSizeRange": "نطاق حجم الحافة",
+        "depth": "D",
+        "max": "Max",
         "degree": "الدرجة",
         "apiKey": "مفتاح واجهة برمجة التطبيقات",
         "enterYourAPIkey": "أدخل مفتاح واجهة برمجة التطبيقات الخاص بك",
@@ -171,6 +190,9 @@
       "fullScreenControl": {
         "fullScreen": "شاشة كاملة",
         "windowed": "نوافذ"
-      }
+      },
+      "legendControl": {
+        "toggleLegend": "تبديل المفتاح"
+      }
     },
     "statusIndicator": {

View File

@@ -32,14 +32,24 @@
   "authDisabled": "Authentication is disabled. Using login free mode.",
   "guestMode": "Login Free"
 },
+"common": {
+  "cancel": "Cancel"
+},
 "documentPanel": {
   "clearDocuments": {
     "button": "Clear",
     "tooltip": "Clear documents",
     "title": "Clear Documents",
+    "description": "This will remove all documents from the system",
+    "warning": "WARNING: This action will permanently delete all documents and cannot be undone!",
     "confirm": "Do you really want to clear all documents?",
+    "confirmPrompt": "Type 'yes' to confirm this action",
+    "confirmPlaceholder": "Type yes to confirm",
+    "clearCache": "Clear LLM cache",
     "confirmButton": "YES",
     "success": "Documents cleared successfully",
+    "cacheCleared": "Cache cleared successfully",
+    "cacheClearFailed": "Failed to clear cache:\n{{error}}",
     "failed": "Clear Documents Failed:\n{{message}}",
     "error": "Clear Documents Failed:\n{{error}}"
   },
@@ -95,6 +105,7 @@
     "metadata": "Metadata"
   },
   "status": {
+    "all": "All",
     "completed": "Completed",
     "processing": "Processing",
     "pending": "Pending",
@@ -127,6 +138,11 @@
     }
   },
   "graphPanel": {
+    "dataIsTruncated": "Graph data is truncated to Max Nodes",
+    "statusDialog": {
+      "title": "LightRAG Server Settings"
+    },
+    "legend": "Legend",
     "sideBar": {
       "settings": {
         "settings": "Settings",
@@ -139,9 +155,12 @@
         "hideUnselectedEdges": "Hide Unselected Edges",
         "edgeEvents": "Edge Events",
         "maxQueryDepth": "Max Query Depth",
-        "minDegree": "Minimum Degree",
+        "maxNodes": "Max Nodes",
         "maxLayoutIterations": "Max Layout Iterations",
-        "depth": "Depth",
+        "resetToDefault": "Reset to default",
+        "edgeSizeRange": "Edge Size Range",
+        "depth": "D",
+        "max": "Max",
         "degree": "Degree",
         "apiKey": "API Key",
         "enterYourAPIkey": "Enter your API key",
@@ -171,6 +190,9 @@
       "fullScreenControl": {
         "fullScreen": "Full Screen",
         "windowed": "Windowed"
-      }
+      },
+      "legendControl": {
+        "toggleLegend": "Toggle Legend"
+      }
     },
     "statusIndicator": {

View File

@@ -32,14 +32,24 @@
   "authDisabled": "L'authentification est désactivée. Utilisation du mode sans connexion.",
   "guestMode": "Mode sans connexion"
 },
+"common": {
+  "cancel": "Annuler"
+},
 "documentPanel": {
   "clearDocuments": {
     "button": "Effacer",
     "tooltip": "Effacer les documents",
     "title": "Effacer les documents",
+    "description": "Cette action supprimera tous les documents du système",
+    "warning": "ATTENTION : Cette action supprimera définitivement tous les documents et ne peut pas être annulée !",
     "confirm": "Voulez-vous vraiment effacer tous les documents ?",
+    "confirmPrompt": "Tapez 'yes' pour confirmer cette action",
+    "confirmPlaceholder": "Tapez yes pour confirmer",
+    "clearCache": "Effacer le cache LLM",
     "confirmButton": "OUI",
     "success": "Documents effacés avec succès",
+    "cacheCleared": "Cache effacé avec succès",
+    "cacheClearFailed": "Échec de l'effacement du cache :\n{{error}}",
     "failed": "Échec de l'effacement des documents :\n{{message}}",
     "error": "Échec de l'effacement des documents :\n{{error}}"
   },
@@ -95,6 +105,7 @@
     "metadata": "Métadonnées"
   },
   "status": {
+    "all": "Tous",
     "completed": "Terminé",
     "processing": "En traitement",
     "pending": "En attente",
@@ -127,6 +138,11 @@
     }
   },
   "graphPanel": {
+    "dataIsTruncated": "Les données du graphe sont tronquées au nombre maximum de nœuds",
+    "statusDialog": {
+      "title": "Paramètres du Serveur LightRAG"
+    },
+    "legend": "Légende",
     "sideBar": {
       "settings": {
         "settings": "Paramètres",
@@ -139,9 +155,12 @@
         "hideUnselectedEdges": "Masquer les arêtes non sélectionnées",
         "edgeEvents": "Événements des arêtes",
         "maxQueryDepth": "Profondeur maximale de la requête",
-        "minDegree": "Degré minimum",
+        "maxNodes": "Nombre maximum de nœuds",
         "maxLayoutIterations": "Itérations maximales de mise en page",
-        "depth": "Profondeur",
+        "resetToDefault": "Réinitialiser par défaut",
+        "edgeSizeRange": "Plage de taille des arêtes",
+        "depth": "D",
+        "max": "Max",
         "degree": "Degré",
         "apiKey": "Clé API",
         "enterYourAPIkey": "Entrez votre clé API",
@@ -171,6 +190,9 @@
       "fullScreenControl": {
         "fullScreen": "Plein écran",
         "windowed": "Fenêtré"
-      }
+      },
+      "legendControl": {
+        "toggleLegend": "Basculer la légende"
+      }
     },
     "statusIndicator": {

View File

@@ -32,14 +32,24 @@
   "authDisabled": "认证已禁用,使用无需登陆模式。",
   "guestMode": "无需登陆"
 },
+"common": {
+  "cancel": "取消"
+},
 "documentPanel": {
   "clearDocuments": {
     "button": "清空",
     "tooltip": "清空文档",
     "title": "清空文档",
+    "description": "此操作将从系统中移除所有文档",
+    "warning": "警告:此操作将永久删除所有文档,无法恢复!",
     "confirm": "确定要清空所有文档吗?",
+    "confirmPrompt": "请输入 yes 确认操作",
+    "confirmPlaceholder": "输入 yes 确认",
+    "clearCache": "清空LLM缓存",
     "confirmButton": "确定",
     "success": "文档清空成功",
+    "cacheCleared": "缓存清空成功",
+    "cacheClearFailed": "清空缓存失败:\n{{error}}",
     "failed": "清空文档失败:\n{{message}}",
     "error": "清空文档失败:\n{{error}}"
   },
@@ -95,6 +105,7 @@
     "metadata": "元数据"
   },
   "status": {
+    "all": "全部",
     "completed": "已完成",
     "processing": "处理中",
     "pending": "等待中",
@@ -127,6 +138,11 @@
     }
   },
   "graphPanel": {
+    "dataIsTruncated": "图数据已截断至最大返回节点数",
+    "statusDialog": {
+      "title": "LightRAG 服务器设置"
+    },
+    "legend": "图例",
     "sideBar": {
       "settings": {
         "settings": "设置",
@@ -139,9 +155,12 @@
         "hideUnselectedEdges": "隐藏未选中的边",
         "edgeEvents": "边事件",
         "maxQueryDepth": "最大查询深度",
-        "minDegree": "最小邻边数",
+        "maxNodes": "最大返回节点数",
         "maxLayoutIterations": "最大布局迭代次数",
-        "depth": "深度",
+        "resetToDefault": "重置为默认值",
+        "edgeSizeRange": "边粗细范围",
+        "depth": "深",
+        "max": "Max",
         "degree": "邻边",
         "apiKey": "API密钥",
         "enterYourAPIkey": "输入您的API密钥",
@@ -171,6 +190,9 @@
       "fullScreenControl": {
         "fullScreen": "全屏",
         "windowed": "窗口"
-      }
+      },
+      "legendControl": {
+        "toggleLegend": "切换图例显示"
+      }
     },
     "statusIndicator": {

View File

@@ -32,7 +32,7 @@ class NavigationService {
     // Reset backend state
     useBackendState.getState().clear();

-    // Reset retrieval history while preserving other user preferences
+    // Reset retrieval history message while preserving other user preferences
     useSettingsStore.getState().setRetrievalHistory([]);

     // Clear authentication state

View File

@@ -77,6 +77,8 @@ interface GraphState {
   graphIsEmpty: boolean
   lastSuccessfulQueryLabel: string

+  typeColorMap: Map<string, string>
+
   // Global flags to track data fetching attempts
   graphDataFetchAttempted: boolean
   labelsFetchAttempted: boolean
@@ -136,6 +138,8 @@ const useGraphStoreBase = create<GraphState>()((set) => ({
   sigmaInstance: null,
   allDatabaseLabels: ['*'],

+  typeColorMap: new Map<string, string>(),
+
   searchEngine: null,

   setGraphIsEmpty: (isEmpty: boolean) => set({ graphIsEmpty: isEmpty }),
@@ -166,7 +170,6 @@ const useGraphStoreBase = create<GraphState>()((set) => ({
       searchEngine: null,
       moveToSelectedNode: false,
       graphIsEmpty: false
-      // Do not reset lastSuccessfulQueryLabel here as it's used to track query history
     });
   },
@@ -199,6 +202,8 @@ const useGraphStoreBase = create<GraphState>()((set) => ({
   setSigmaInstance: (instance: any) => set({ sigmaInstance: instance }),

+  setTypeColorMap: (typeColorMap: Map<string, string>) => set({ typeColorMap }),
+
   setSearchEngine: (engine: MiniSearch | null) => set({ searchEngine: engine }),
   resetSearchEngine: () => set({ searchEngine: null }),

View File

@@ -16,6 +16,8 @@ interface SettingsState {
   // Graph viewer settings
   showPropertyPanel: boolean
   showNodeSearchBar: boolean
+  showLegend: boolean
+  setShowLegend: (show: boolean) => void

   showNodeLabel: boolean
   enableNodeDrag: boolean
@@ -24,11 +26,17 @@ interface SettingsState {
   enableHideUnselectedEdges: boolean
   enableEdgeEvents: boolean

+  minEdgeSize: number
+  setMinEdgeSize: (size: number) => void
+
+  maxEdgeSize: number
+  setMaxEdgeSize: (size: number) => void
+
   graphQueryMaxDepth: number
   setGraphQueryMaxDepth: (depth: number) => void
-  graphMinDegree: number
-  setGraphMinDegree: (degree: number) => void
+  graphMaxNodes: number
+  setGraphMaxNodes: (nodes: number) => void
   graphLayoutMaxIterations: number
   setGraphLayoutMaxIterations: (iterations: number) => void
@@ -68,6 +76,7 @@ const useSettingsStoreBase = create<SettingsState>()(
       language: 'en',
       showPropertyPanel: true,
       showNodeSearchBar: true,
+      showLegend: false,
       showNodeLabel: true,
       enableNodeDrag: true,
@@ -76,8 +85,11 @@ const useSettingsStoreBase = create<SettingsState>()(
       enableHideUnselectedEdges: true,
       enableEdgeEvents: false,

+      minEdgeSize: 1,
+      maxEdgeSize: 1,
       graphQueryMaxDepth: 3,
-      graphMinDegree: 0,
+      graphMaxNodes: 1000,
       graphLayoutMaxIterations: 15,
       queryLabel: defaultQueryLabel,
@@ -130,7 +142,11 @@ const useSettingsStoreBase = create<SettingsState>()(
       setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),

-      setGraphMinDegree: (degree: number) => set({ graphMinDegree: degree }),
+      setGraphMaxNodes: (nodes: number) => set({ graphMaxNodes: nodes }),
+
+      setMinEdgeSize: (size: number) => set({ minEdgeSize: size }),
+      setMaxEdgeSize: (size: number) => set({ maxEdgeSize: size }),

       setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
@@ -145,12 +161,13 @@ const useSettingsStoreBase = create<SettingsState>()(
         querySettings: { ...state.querySettings, ...settings }
       })),

-      setShowFileName: (show: boolean) => set({ showFileName: show })
+      setShowFileName: (show: boolean) => set({ showFileName: show }),
+      setShowLegend: (show: boolean) => set({ showLegend: show })
     }),
     {
       name: 'settings-storage',
       storage: createJSONStorage(() => localStorage),
-      version: 9,
+      version: 11,
       migrate: (state: any, version: number) => {
         if (version < 2) {
           state.showEdgeLabel = false
@@ -196,6 +213,14 @@ const useSettingsStoreBase = create<SettingsState>()(
         if (version < 9) {
           state.showFileName = false
         }
+        if (version < 10) {
+          delete state.graphMinDegree // remove the deprecated parameter
+          state.graphMaxNodes = 1000 // add its replacement
+        }
+        if (version < 11) {
+          state.minEdgeSize = 1
+          state.maxEdgeSize = 1
+        }
         return state
       }
     }
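The `migrate` hook above upgrades stale localStorage payloads field by field whenever the persisted `version` lags behind the current one. A minimal sketch of the same pattern in isolation, assuming the `zustand` persist middleware (the store shape here is illustrative, not the real one):

```typescript
import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'

interface GraphSettings {
  graphMaxNodes: number
  minEdgeSize: number
  maxEdgeSize: number
}

// Bump `version` whenever the persisted shape changes; `migrate`
// receives the old payload plus the version it was written with.
export const useGraphSettings = create<GraphSettings>()(
  persist(
    () => ({ graphMaxNodes: 1000, minEdgeSize: 1, maxEdgeSize: 1 }),
    {
      name: 'graph-settings', // localStorage key
      storage: createJSONStorage(() => localStorage),
      version: 11,
      migrate: (state: any, version: number) => {
        if (version < 10) {
          delete state.graphMinDegree // drop the retired field
          state.graphMaxNodes = 1000  // seed its replacement
        }
        if (version < 11) {
          state.minEdgeSize = 1
          state.maxEdgeSize = 1
        }
        return state
      }
    }
  )
)
```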

View File

@@ -22,10 +22,13 @@ interface AuthState {
   coreVersion: string | null;
   apiVersion: string | null;
   username: string | null; // login username
+  webuiTitle: string | null; // Custom title
+  webuiDescription: string | null; // Title description

-  login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null) => void;
+  login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null, webuiTitle?: string | null, webuiDescription?: string | null) => void;
   logout: () => void;
   setVersion: (coreVersion: string | null, apiVersion: string | null) => void;
+  setCustomTitle: (webuiTitle: string | null, webuiDescription: string | null) => void;
 }

 const useBackendStateStoreBase = create<BackendState>()((set) => ({
@@ -47,6 +50,14 @@ const useBackendStateStoreBase = create<BackendState>()((set) => ({
       );
     }

+    // Update custom title information if health check returns it
+    if ('webui_title' in health || 'webui_description' in health) {
+      useAuthStore.getState().setCustomTitle(
+        'webui_title' in health ? (health.webui_title ?? null) : null,
+        'webui_description' in health ? (health.webui_description ?? null) : null
+      );
+    }
+
     set({
       health: true,
       message: null,
@@ -107,10 +118,12 @@ const isGuestToken = (token: string): boolean => {
   return payload.role === 'guest';
 };

-const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null } => {
+const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null } => {
   const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
   const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
   const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
+  const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
+  const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');
   const username = token ? getUsernameFromToken(token) : null;

   if (!token) {
@@ -120,6 +133,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
       coreVersion: coreVersion,
       apiVersion: apiVersion,
       username: null,
+      webuiTitle: webuiTitle,
+      webuiDescription: webuiDescription,
     };
   }
@@ -129,6 +144,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
     coreVersion: coreVersion,
     apiVersion: apiVersion,
     username: username,
+    webuiTitle: webuiTitle,
+    webuiDescription: webuiDescription,
   };
 };
@@ -142,8 +159,10 @@ export const useAuthStore = create<AuthState>(set => {
     coreVersion: initialState.coreVersion,
     apiVersion: initialState.apiVersion,
     username: initialState.username,
+    webuiTitle: initialState.webuiTitle,
+    webuiDescription: initialState.webuiDescription,

-    login: (token, isGuest = false, coreVersion = null, apiVersion = null) => {
+    login: (token, isGuest = false, coreVersion = null, apiVersion = null, webuiTitle = null, webuiDescription = null) => {
       localStorage.setItem('LIGHTRAG-API-TOKEN', token);

       if (coreVersion) {
@@ -153,6 +172,18 @@ export const useAuthStore = create<AuthState>(set => {
         localStorage.setItem('LIGHTRAG-API-VERSION', apiVersion);
       }

+      if (webuiTitle) {
+        localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle);
+      } else {
+        localStorage.removeItem('LIGHTRAG-WEBUI-TITLE');
+      }
+
+      if (webuiDescription) {
+        localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription);
+      } else {
+        localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION');
+      }
+
       const username = getUsernameFromToken(token);
       set({
         isAuthenticated: true,
@@ -160,6 +191,8 @@ export const useAuthStore = create<AuthState>(set => {
         username: username,
         coreVersion: coreVersion,
         apiVersion: apiVersion,
+        webuiTitle: webuiTitle,
+        webuiDescription: webuiDescription,
       });
     },
@@ -168,6 +201,8 @@ export const useAuthStore = create<AuthState>(set => {
       const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
       const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
+      const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
+      const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');

       set({
         isAuthenticated: false,
@@ -175,6 +210,8 @@ export const useAuthStore = create<AuthState>(set => {
         username: null,
         coreVersion: coreVersion,
         apiVersion: apiVersion,
+        webuiTitle: webuiTitle,
+        webuiDescription: webuiDescription,
       });
     },
@@ -192,6 +229,27 @@ export const useAuthStore = create<AuthState>(set => {
         coreVersion: coreVersion,
         apiVersion: apiVersion
       });
-    }
+    },
+
+    setCustomTitle: (webuiTitle, webuiDescription) => {
+      // Update localStorage
+      if (webuiTitle) {
+        localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle);
+      } else {
+        localStorage.removeItem('LIGHTRAG-WEBUI-TITLE');
+      }
+
+      if (webuiDescription) {
+        localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription);
+      } else {
+        localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION');
+      }
+
+      // Update state
+      set({
+        webuiTitle: webuiTitle,
+        webuiDescription: webuiDescription
+      });
+    }
   };
 });
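Because the title fields are mirrored into localStorage, the custom banner survives a page reload even before the next health check lands. A hedged sketch of the round trip, assuming the store API added in this commit (the import path and `health` object are illustrative):

```typescript
import { useAuthStore } from '@/stores/state' // path is an assumption

// After a health check that carries the branding fields...
const health = { webui_title: 'Team KG', webui_description: 'Internal knowledge graph' }

// ...push them into the store; passing nulls clears both state and localStorage.
useAuthStore.getState().setCustomTitle(
  health.webui_title ?? null,
  health.webui_description ?? null
)

// On the next cold start, initAuthState() re-reads
// LIGHTRAG-WEBUI-TITLE / LIGHTRAG-WEBUI-DESCRIPTION from localStorage,
// so the header can render the custom title before any request completes.
console.log(useAuthStore.getState().webuiTitle) // 'Team KG'
```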

tests/test_graph_storage.py (new file, 440 lines)
View File

@@ -0,0 +1,440 @@
#!/usr/bin/env python
"""
Universal graph storage test program.

Selects the graph storage type according to the LIGHTRAG_GRAPH_STORAGE setting
in .env and runs basic and advanced operation tests against it.

Supported graph storage types:
- NetworkXStorage
- Neo4JStorage
- PGGraphStorage
"""

import asyncio
import os
import sys
import importlib
import numpy as np
from dotenv import load_dotenv
from ascii_colors import ASCIIColors

# Add the project root directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from lightrag.types import KnowledgeGraph
from lightrag.kg import (
    STORAGE_IMPLEMENTATIONS,
    STORAGE_ENV_REQUIREMENTS,
    STORAGES,
    verify_storage_implementation,
)
from lightrag.kg.shared_storage import initialize_share_data


# Mock embedding function that returns random vectors
async def mock_embedding_func(texts):
    return np.random.rand(len(texts), 10)  # return 10-dimensional random vectors


def check_env_file():
    """
    Warn if no .env file exists in the current directory.
    Returns True if execution should continue, False if it should stop.
    """
    if not os.path.exists(".env"):
        warning_msg = "Warning: no .env file found in the current directory. This may affect loading of the storage configuration."
        ASCIIColors.yellow(warning_msg)

        # Check whether we are running in an interactive terminal
        if sys.stdin.isatty():
            response = input("Continue anyway? (yes/no): ")
            if response.lower() != "yes":
                ASCIIColors.red("Test program cancelled")
                return False
    return True


async def initialize_graph_storage():
    """
    Initialize the graph storage instance selected via environment variables.
    Returns the initialized storage instance.
    """
    # Read the graph storage type from the environment
    graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")

    # Verify that the storage type is valid
    try:
        verify_storage_implementation("GRAPH_STORAGE", graph_storage_type)
    except ValueError as e:
        ASCIIColors.red(f"Error: {str(e)}")
        ASCIIColors.yellow(
            f"Supported graph storage types: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
        )
        return None

    # Check the required environment variables
    required_env_vars = STORAGE_ENV_REQUIREMENTS.get(graph_storage_type, [])
    missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
    if missing_env_vars:
        ASCIIColors.red(
            f"Error: {graph_storage_type} requires the following environment variables, which are not set: {', '.join(missing_env_vars)}"
        )
        return None

    # Dynamically import the matching module
    module_path = STORAGES.get(graph_storage_type)
    if not module_path:
        ASCIIColors.red(f"Error: no module path found for {graph_storage_type}")
        return None

    try:
        module = importlib.import_module(module_path, package="lightrag")
        storage_class = getattr(module, graph_storage_type)
    except (ImportError, AttributeError) as e:
        ASCIIColors.red(f"Error: failed to import {graph_storage_type}: {str(e)}")
        return None

    # Initialize the storage instance
    global_config = {
        "embedding_batch_num": 10,  # batch size
        "vector_db_storage_cls_kwargs": {
            "cosine_better_than_threshold": 0.5  # cosine similarity threshold
        },
        "working_dir": os.environ.get("WORKING_DIR", "./rag_storage"),  # working directory
    }

    # NetworkXStorage requires shared_storage to be initialized first
    if graph_storage_type == "NetworkXStorage":
        initialize_share_data()  # single-process mode

    try:
        storage = storage_class(
            namespace="test_graph",
            global_config=global_config,
            embedding_func=mock_embedding_func,
        )
        # Initialize the connection
        await storage.initialize()
        return storage
    except Exception as e:
        ASCIIColors.red(f"Error: failed to initialize {graph_storage_type}: {str(e)}")
        return None


async def test_graph_basic(storage):
    """
    Test the basic graph database operations:
    1. Insert two nodes with upsert_node
    2. Insert an edge connecting the two nodes with upsert_edge
    3. Read one node back with get_node
    4. Read one edge back with get_edge
    """
    try:
        # Clean up data left over from previous tests
        print("Cleaning up data from previous tests...")
        await storage.drop()

        # 1. Insert the first node
        node1_id = "人工智能"
        node1_data = {
            "entity_id": node1_id,
            "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
            "keywords": "AI,机器学习,深度学习",
            "entity_type": "技术领域",
        }
        print(f"Inserting node 1: {node1_id}")
        await storage.upsert_node(node1_id, node1_data)

        # 2. Insert the second node
        node2_id = "机器学习"
        node2_data = {
            "entity_id": node2_id,
            "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
            "keywords": "监督学习,无监督学习,强化学习",
            "entity_type": "技术领域",
        }
        print(f"Inserting node 2: {node2_id}")
        await storage.upsert_node(node2_id, node2_data)

        # 3. Insert the connecting edge
        edge_data = {
            "relationship": "包含",
            "weight": 1.0,
            "description": "人工智能领域包含机器学习这个子领域",
        }
        print(f"Inserting edge: {node1_id} -> {node2_id}")
        await storage.upsert_edge(node1_id, node2_id, edge_data)

        # 4. Read node properties back
        print(f"Reading node properties: {node1_id}")
        node1_props = await storage.get_node(node1_id)
        if node1_props:
            print(f"Successfully read node properties: {node1_id}")
            print(f"Node description: {node1_props.get('description', 'no description')}")
            print(f"Node type: {node1_props.get('entity_type', 'no type')}")
            print(f"Node keywords: {node1_props.get('keywords', 'no keywords')}")
            # Verify that the returned properties are correct
            assert (
                node1_props.get("entity_id") == node1_id
            ), f"Node ID mismatch: expected {node1_id}, got {node1_props.get('entity_id')}"
            assert (
                node1_props.get("description") == node1_data["description"]
            ), "Node description mismatch"
            assert (
                node1_props.get("entity_type") == node1_data["entity_type"]
            ), "Node type mismatch"
        else:
            print(f"Failed to read node properties: {node1_id}")
            assert False, f"Could not read node properties: {node1_id}"

        # 5. Read edge properties back
        print(f"Reading edge properties: {node1_id} -> {node2_id}")
        edge_props = await storage.get_edge(node1_id, node2_id)
        if edge_props:
            print(f"Successfully read edge properties: {node1_id} -> {node2_id}")
            print(f"Edge relationship: {edge_props.get('relationship', 'no relationship')}")
            print(f"Edge description: {edge_props.get('description', 'no description')}")
            print(f"Edge weight: {edge_props.get('weight', 'no weight')}")
            # Verify that the returned properties are correct
            assert (
                edge_props.get("relationship") == edge_data["relationship"]
            ), "Edge relationship mismatch"
            assert (
                edge_props.get("description") == edge_data["description"]
            ), "Edge description mismatch"
            assert edge_props.get("weight") == edge_data["weight"], "Edge weight mismatch"
        else:
            print(f"Failed to read edge properties: {node1_id} -> {node2_id}")
            assert False, f"Could not read edge properties: {node1_id} -> {node2_id}"

        print("Basic tests finished; the data has been kept in the database")
        return True
    except Exception as e:
        ASCIIColors.red(f"An error occurred during the test: {str(e)}")
        return False


async def test_graph_advanced(storage):
    """
    Test the advanced graph database operations:
    1. node_degree - get the degree of a node
    2. edge_degree - get the degree of an edge
    3. get_node_edges - get all edges of a node
    4. get_all_labels - get all labels
    5. get_knowledge_graph - fetch the knowledge graph
    6. delete_node - delete a node
    7. remove_nodes - delete nodes in batch
    8. remove_edges - delete edges
    9. drop - clean up all data
    """
    try:
        # Clean up data left over from previous tests
        print("Cleaning up data from previous tests...\n")
        await storage.drop()

        # 1. Insert test data
        # Insert node 1: 人工智能
        node1_id = "人工智能"
        node1_data = {
            "entity_id": node1_id,
            "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
            "keywords": "AI,机器学习,深度学习",
            "entity_type": "技术领域",
        }
        print(f"Inserting node 1: {node1_id}")
        await storage.upsert_node(node1_id, node1_data)

        # Insert node 2: 机器学习
        node2_id = "机器学习"
        node2_data = {
            "entity_id": node2_id,
            "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
            "keywords": "监督学习,无监督学习,强化学习",
            "entity_type": "技术领域",
        }
        print(f"Inserting node 2: {node2_id}")
        await storage.upsert_node(node2_id, node2_data)

        # Insert node 3: 深度学习
        node3_id = "深度学习"
        node3_data = {
            "entity_id": node3_id,
            "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
            "keywords": "神经网络,CNN,RNN",
            "entity_type": "技术领域",
        }
        print(f"Inserting node 3: {node3_id}")
        await storage.upsert_node(node3_id, node3_data)

        # Insert edge 1: 人工智能 -> 机器学习
        edge1_data = {
            "relationship": "包含",
            "weight": 1.0,
            "description": "人工智能领域包含机器学习这个子领域",
        }
        print(f"Inserting edge 1: {node1_id} -> {node2_id}")
        await storage.upsert_edge(node1_id, node2_id, edge1_data)

        # Insert edge 2: 机器学习 -> 深度学习
        edge2_data = {
            "relationship": "包含",
            "weight": 1.0,
            "description": "机器学习领域包含深度学习这个子领域",
        }
        print(f"Inserting edge 2: {node2_id} -> {node3_id}")
        await storage.upsert_edge(node2_id, node3_id, edge2_data)

        # 2. Test node_degree - get the degree of a node
        print(f"== Testing node_degree: {node1_id}")
        node1_degree = await storage.node_degree(node1_id)
        print(f"Degree of node {node1_id}: {node1_degree}")
        assert node1_degree == 1, f"Degree of node {node1_id} should be 1, got {node1_degree}"

        # 3. Test edge_degree - get the degree of an edge
        print(f"== Testing edge_degree: {node1_id} -> {node2_id}")
        edge_degree = await storage.edge_degree(node1_id, node2_id)
        print(f"Degree of edge {node1_id} -> {node2_id}: {edge_degree}")
        assert (
            edge_degree == 3
        ), f"Degree of edge {node1_id} -> {node2_id} should be 3, got {edge_degree}"

        # 4. Test get_node_edges - get all edges of a node
        print(f"== Testing get_node_edges: {node2_id}")
        node2_edges = await storage.get_node_edges(node2_id)
        print(f"All edges of node {node2_id}: {node2_edges}")
        assert (
            len(node2_edges) == 2
        ), f"Node {node2_id} should have 2 edges, got {len(node2_edges)}"

        # 5. Test get_all_labels - get all labels
        print("== Testing get_all_labels")
        all_labels = await storage.get_all_labels()
        print(f"All labels: {all_labels}")
        assert len(all_labels) == 3, f"There should be 3 labels, got {len(all_labels)}"
        assert node1_id in all_labels, f"{node1_id} should be in the label list"
        assert node2_id in all_labels, f"{node2_id} should be in the label list"
        assert node3_id in all_labels, f"{node3_id} should be in the label list"

        # 6. Test get_knowledge_graph - fetch the knowledge graph
        print("== Testing get_knowledge_graph")
        kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
        print(f"Number of nodes in the knowledge graph: {len(kg.nodes)}")
        print(f"Number of edges in the knowledge graph: {len(kg.edges)}")
        assert isinstance(kg, KnowledgeGraph), "The result should be a KnowledgeGraph instance"
        assert len(kg.nodes) == 3, f"The knowledge graph should have 3 nodes, got {len(kg.nodes)}"
        assert len(kg.edges) == 2, f"The knowledge graph should have 2 edges, got {len(kg.edges)}"

        # 7. Test delete_node - delete a node
        print(f"== Testing delete_node: {node3_id}")
        await storage.delete_node(node3_id)
        node3_props = await storage.get_node(node3_id)
        print(f"Node properties of {node3_id} after deletion: {node3_props}")
        assert node3_props is None, f"Node {node3_id} should have been deleted"

        # Re-insert node 3 for the following tests
        await storage.upsert_node(node3_id, node3_data)
        await storage.upsert_edge(node2_id, node3_id, edge2_data)

        # 8. Test remove_edges - delete edges
        print(f"== Testing remove_edges: {node2_id} -> {node3_id}")
        await storage.remove_edges([(node2_id, node3_id)])
        edge_props = await storage.get_edge(node2_id, node3_id)
        print(f"Edge properties of {node2_id} -> {node3_id} after deletion: {edge_props}")
        assert edge_props is None, f"Edge {node2_id} -> {node3_id} should have been deleted"

        # 9. Test remove_nodes - delete nodes in batch
        print(f"== Testing remove_nodes: [{node2_id}, {node3_id}]")
        await storage.remove_nodes([node2_id, node3_id])
        node2_props = await storage.get_node(node2_id)
        node3_props = await storage.get_node(node3_id)
        print(f"Node properties of {node2_id} after deletion: {node2_props}")
        print(f"Node properties of {node3_id} after deletion: {node3_props}")
        assert node2_props is None, f"Node {node2_id} should have been deleted"
        assert node3_props is None, f"Node {node3_id} should have been deleted"

        # 10. Test drop - clean up the data
        print("== Testing drop")
        result = await storage.drop()
        print(f"Cleanup result: {result}")
        assert (
            result["status"] == "success"
        ), f"Cleanup should succeed, actual status was {result['status']}"

        # Verify the cleanup result
        all_labels = await storage.get_all_labels()
        print(f"All labels after cleanup: {all_labels}")
        assert len(all_labels) == 0, f"There should be no labels after cleanup, got {len(all_labels)}"

        print("\nAdvanced tests finished")
        return True
    except Exception as e:
        ASCIIColors.red(f"An error occurred during the test: {str(e)}")
        return False


async def main():
    """Entry point"""
    # Show the program banner
    ASCIIColors.cyan("""
    ╔══════════════════════════════════════════════════════════════╗
    ║             Universal Graph Storage Test Program              ║
    ╚══════════════════════════════════════════════════════════════╝
    """)

    # Check the .env file
    if not check_env_file():
        return

    # Load environment variables
    load_dotenv(dotenv_path=".env", override=False)

    # Read the graph storage type
    graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
    ASCIIColors.magenta(f"\nConfigured graph storage type: {graph_storage_type}")
    ASCIIColors.white(
        f"Supported graph storage types: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
    )

    # Initialize the storage instance
    storage = await initialize_graph_storage()
    if not storage:
        ASCIIColors.red("Failed to initialize the storage instance; exiting")
        return

    try:
        # Show the test menu
        ASCIIColors.yellow("\nSelect a test type:")
        ASCIIColors.white("1. Basic tests (insert and read nodes and edges)")
        ASCIIColors.white("2. Advanced tests (degrees, labels, knowledge graph, deletions, etc.)")
        ASCIIColors.white("3. All tests")
        choice = input("\nEnter an option (1/2/3): ")

        if choice == "1":
            await test_graph_basic(storage)
        elif choice == "2":
            await test_graph_advanced(storage)
        elif choice == "3":
            ASCIIColors.cyan("\n=== Running basic tests ===")
            basic_result = await test_graph_basic(storage)
            if basic_result:
                ASCIIColors.cyan("\n=== Running advanced tests ===")
                await test_graph_advanced(storage)
        else:
            ASCIIColors.red("Invalid option")
    finally:
        # Close the connection
        if storage:
            await storage.finalize()
            ASCIIColors.green("\nStorage connection closed")


if __name__ == "__main__":
    asyncio.run(main())