Merge branch 'main' into main
56
README-zh.md
@@ -11,7 +11,6 @@
- [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), letting users enhance the system with their own domain expertise.
- [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
- [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
- [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which lets you insert, query, visualize, and download LightRAG knowledge.
- [X] [2024.11.04]🎯📢You can now [use Neo4J for storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
@@ -410,6 +409,54 @@ if __name__ == "__main__":

</details>

### Token Usage Tracking
<details>
<summary> <b>Overview and Usage</b> </summary>

LightRAG provides a TokenTracker tool to track and manage the token consumption of large language models. This feature is particularly useful for controlling API costs and optimizing performance.

#### Usage

```python
from lightrag.utils import TokenTracker

# Create a TokenTracker instance
token_tracker = TokenTracker()

# Method 1: Use the context manager (recommended)
# Suitable for scenarios that need automatic token usage tracking
with token_tracker:
    result1 = await llm_model_func("your question 1")
    result2 = await llm_model_func("your question 2")

# Method 2: Manually add token usage records
# Suitable for scenarios that need finer control over token statistics
token_tracker.reset()

rag.insert()

rag.query("your question 1", param=QueryParam(mode="naive"))
rag.query("your question 2", param=QueryParam(mode="mix"))

# Display total token usage (including insert and query operations)
print("Token usage:", token_tracker.get_usage())
```

#### Usage Tips
- Use the context manager in long sessions or batch operations to track all token consumption automatically
- For scenarios that need segmented statistics, use manual mode and call reset() when appropriate
- Check token usage regularly to detect abnormal consumption early
- Use this feature actively during development and testing to optimize production costs

#### Practical Examples
You can refer to the following examples to implement token tracking:
- `examples/lightrag_gemini_track_token_demo.py`: Token tracking example using the Google Gemini model
- `examples/lightrag_siliconcloud_track_token_demo.py`: Token tracking example using the SiliconCloud model

These examples show how to use the TokenTracker feature effectively with different models and scenarios.

</details>

### Conversation History

LightRAG now supports multi-turn dialogue through the conversation history feature. Here is how to use it:
@@ -1037,9 +1084,10 @@ rag.clear_cache(modes=["local"])
| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `PGKVStorage`, `RedisKVStorage`, `MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `PGVectorStorage`, `MilvusVectorDBStorage`, `ChromaVectorDBStorage`, `FaissVectorDBStorage`, `MongoVectorDBStorage`, `QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `PGGraphStorage`, `AGEStorage` | `NetworkXStorage` |
| **doc_status_storage** | `str` | Storage type for document processing status. Supported types: `JsonDocStatusStorage`, `PGDocStatusStorage`, `MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
59
README.md
@@ -41,7 +41,6 @@
- [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
- [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
- [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
- [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
- [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
@@ -443,6 +442,55 @@ if __name__ == "__main__":

</details>

### Token Usage Tracking

<details>
<summary> <b>Overview and Usage</b> </summary>

LightRAG provides a TokenTracker tool to monitor and manage token consumption by large language models. This feature is particularly useful for controlling API costs and optimizing performance.

#### Usage

```python
from lightrag.utils import TokenTracker

# Create TokenTracker instance
token_tracker = TokenTracker()

# Method 1: Using context manager (Recommended)
# Suitable for scenarios requiring automatic token usage tracking
with token_tracker:
    result1 = await llm_model_func("your question 1")
    result2 = await llm_model_func("your question 2")

# Method 2: Manually adding token usage records
# Suitable for scenarios requiring more granular control over token statistics
token_tracker.reset()

rag.insert()

rag.query("your question 1", param=QueryParam(mode="naive"))
rag.query("your question 2", param=QueryParam(mode="mix"))

# Display total token usage (including insert and query operations)
print("Token usage:", token_tracker.get_usage())
```

#### Usage Tips
- Use context managers for long sessions or batch operations to automatically track all token consumption
- For scenarios requiring segmented statistics, use manual mode and call reset() when appropriate
- Check token usage regularly to detect abnormal consumption early
- Actively use this feature during development and testing to optimize production costs
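
The tips above (automatic tracking with a context manager, segmented statistics with `reset()`) can be illustrated with a small stand-in class. This is only a sketch: `SimpleTokenTracker`, its `add_usage` method, and the token counts are hypothetical and are not LightRAG's actual `TokenTracker` implementation; they only show how a tracker that accumulates counts and resets between segments might behave.

```python
class SimpleTokenTracker:
    """Hypothetical, minimal token counter (not LightRAG's TokenTracker)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Zero all counters, e.g. before measuring a new segment."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.call_count = 0

    def add_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Record the token counts reported for one LLM call."""
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
        self.call_count += 1

    def get_usage(self) -> dict:
        """Return a snapshot of the accumulated counters."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.prompt_tokens + self.completion_tokens,
            "call_count": self.call_count,
        }

    def __enter__(self) -> "SimpleTokenTracker":
        # Assumption: entering the context starts a fresh measurement.
        self.reset()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        # Counters stay readable after the block; exceptions propagate.
        return False


tracker = SimpleTokenTracker()

# Segment 1: pretend an insert step made two LLM calls (made-up numbers).
tracker.add_usage(prompt_tokens=900, completion_tokens=150)
tracker.add_usage(prompt_tokens=400, completion_tokens=50)
insert_usage = tracker.get_usage()

# Segment 2: the context manager resets first, so only this block is counted.
with tracker:
    tracker.add_usage(prompt_tokens=300, completion_tokens=120)
query_usage = tracker.get_usage()

print("insert:", insert_usage)
print("query:", query_usage)
```

Taking a snapshot with `get_usage()` before each reset is what makes per-segment reporting possible; without it, the insert totals would be lost when the second segment starts.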

#### Practical Examples
You can refer to these examples for implementing token tracking:
- `examples/lightrag_gemini_track_token_demo.py`: Token tracking example using Google Gemini model
- `examples/lightrag_siliconcloud_track_token_demo.py`: Token tracking example using SiliconCloud model

These examples demonstrate how to effectively use the TokenTracker feature with different models and scenarios.

</details>

### Conversation History Support

@@ -607,7 +655,7 @@ The `apipeline_enqueue_documents` and `apipeline_process_enqueue_documents` func

This is useful for scenarios where you want to process documents in the background while still allowing the main thread to continue executing.

And using a routine to process news documents.
And using a routine to process new documents.

```python
rag = LightRAG(..)
@@ -1096,9 +1144,10 @@ Valid modes are:
| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `PGKVStorage`, `RedisKVStorage`, `MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `PGVectorStorage`, `MilvusVectorDBStorage`, `ChromaVectorDBStorage`, `FaissVectorDBStorage`, `MongoVectorDBStorage`, `QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `PGGraphStorage`, `AGEStorage` | `NetworkXStorage` |
| **doc_status_storage** | `str` | Storage type for document processing status. Supported types: `JsonDocStatusStorage`, `PGDocStatusStorage`, `MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
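
To make the interaction between `chunk_token_size` and `chunk_overlap_token_size` concrete, here is a rough sketch of sliding-window chunking over a token list. The helper `chunk_tokens` is hypothetical and uses plain integers in place of Tiktoken token IDs; LightRAG's own splitter may differ in details such as how the final chunk is handled.

```python
def chunk_tokens(tokens, chunk_token_size=1200, chunk_overlap_token_size=100):
    """Split a token sequence into overlapping windows.

    Hypothetical helper for illustration only. Each window holds at most
    `chunk_token_size` tokens and shares `chunk_overlap_token_size` tokens
    with the window before it.
    """
    if chunk_overlap_token_size >= chunk_token_size:
        raise ValueError("overlap must be smaller than the chunk size")

    # Each new window starts this many tokens after the previous one.
    step = chunk_token_size - chunk_overlap_token_size
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_token_size])
        # Stop once a window has reached the end of the sequence.
        if start + chunk_token_size >= len(tokens):
            break
    return chunks


# Stand-in for real token IDs: 2500 consecutive integers.
tokens = list(range(2500))
chunks = chunk_tokens(tokens)

for index, chunk in enumerate(chunks):
    print(index, len(chunk), chunk[0], chunk[-1])
```

With the table's defaults, each window starts 1100 tokens after the previous one, so the last 100 tokens of one chunk reappear at the start of the next.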
@@ -13,23 +13,6 @@ uri=redis://localhost:6379/1
[qdrant]
uri = http://localhost:16333

[oracle]
dsn = localhost:1521/XEPDB1
user = your_username
password = your_password
config_dir = /path/to/oracle/config
wallet_location = /path/to/wallet # optional
wallet_password = your_wallet_password # optional
workspace = default # optional, defaults to default

[tidb]
host = localhost
port = 4000
user = your_username
password = your_password
database = your_database
workspace = default # optional, defaults to default

[postgres]
host = localhost
port = 5432
43
env.example
@@ -4,11 +4,9 @@
# HOST=0.0.0.0
# PORT=9621
# WORKERS=2
### separating data from different LightRAG instances
# NAMESPACE_PREFIX=lightrag
### Max nodes returned from graph retrieval
# MAX_GRAPH_NODES=1000
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
WEBUI_TITLE='Graph RAG Engine'
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"

### Optional SSL Configuration
# SSL=true
@@ -22,6 +20,9 @@
### Ollama Emulating Model Tag
# OLLAMA_EMULATING_MODEL_TAG=latest

### Max nodes returned from graph retrieval
# MAX_GRAPH_NODES=1000

### Logging level
# LOG_LEVEL=INFO
# VERBOSE=False
@@ -110,23 +111,13 @@ LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage

### Oracle Database Configuration
ORACLE_DSN=localhost:1521/XEPDB1
ORACLE_USER=your_username
ORACLE_PASSWORD='your_password'
ORACLE_CONFIG_DIR=/path/to/oracle/config
#ORACLE_WALLET_LOCATION=/path/to/wallet
#ORACLE_WALLET_PASSWORD='your_password'
### separating all data from different LightRAG instances (deprecating, use NAMESPACE_PREFIX in future)
#ORACLE_WORKSPACE=default

### TiDB Configuration
TIDB_HOST=localhost
TIDB_PORT=4000
TIDB_USER=your_username
TIDB_PASSWORD='your_password'
TIDB_DATABASE=your_database
### separating all data from different LightRAG instances (deprecating, use NAMESPACE_PREFIX in future)
### TiDB Configuration (Deprecated)
# TIDB_HOST=localhost
# TIDB_PORT=4000
# TIDB_USER=your_username
# TIDB_PASSWORD='your_password'
# TIDB_DATABASE=your_database
### separating all data from different LightRAG instances (deprecating)
# TIDB_WORKSPACE=default

### PostgreSQL Configuration
@@ -135,7 +126,7 @@ POSTGRES_PORT=5432
POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database
### separating all data from different LightRAG instances (deprecating, use NAMESPACE_PREFIX in future)
### separating all data from different LightRAG instances (deprecating)
# POSTGRES_WORKSPACE=default

### Independent AGE Configuration (not for AGE embedded in PostgreSQL)
@@ -145,8 +136,8 @@ AGE_POSTGRES_PASSWORD=
AGE_POSTGRES_HOST=
# AGE_POSTGRES_PORT=8529

### separating all data from different LightRAG instances (deprecating, use NAMESPACE_PREFIX in future)
# AGE Graph Name (applies to PostgreSQL and independent AGE)
### AGE_GRAPH_NAME is deprecated
# AGE_GRAPH_NAME=lightrag

### Neo4j Configuration
@@ -157,7 +148,7 @@ NEO4J_PASSWORD='your_password'
### MongoDB Configuration
MONGO_URI=mongodb://root:root@localhost:27017/
MONGO_DATABASE=LightRAG
### separating all data from different LightRAG instances (deprecating, use NAMESPACE_PREFIX in future)
### separating all data from different LightRAG instances (deprecating)
# MONGODB_GRAPH=false

### Milvus Configuration
@@ -177,7 +168,9 @@ REDIS_URI=redis://localhost:6379
### For JWT Auth
# AUTH_ACCOUNTS='admin:admin123,user1:pass456'
# TOKEN_SECRET=Your-Key-For-LightRAG-API-Server
# TOKEN_EXPIRE_HOURS=4
# TOKEN_EXPIRE_HOURS=48
# GUEST_TOKEN_EXPIRE_HOURS=24
# JWT_ALGORITHM=HS256

### API-Key to access LightRAG Server API
# LIGHTRAG_API_KEY=your-secure-api-key-here
@@ -1,188 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from contextlib import asynccontextmanager
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
from typing import Optional
import asyncio
import nest_asyncio
import aiofiles
from lightrag.kg.shared_storage import initialize_pipeline_status

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

DEFAULT_RAG_DIR = "index_default"

DEFAULT_INPUT_FILE = "book.txt"
INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}")
print(f"INPUT_FILE: {INPUT_FILE}")

# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def init():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=ollama_model_complete,
        llm_model_name="gemma2:9b",
        llm_model_max_async=4,
        llm_model_max_token_size=8192,
        llm_model_kwargs={
            "host": "http://localhost:11434",
            "options": {"num_ctx": 8192},
        },
        embedding_func=EmbeddingFunc(
            embedding_dim=768,
            max_token_size=8192,
            func=lambda texts: ollama_embed(
                texts, embed_model="nomic-embed-text", host="http://localhost:11434"
            ),
        ),
    )

    # Add initialization code
    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


@asynccontextmanager
async def lifespan(app: FastAPI):
    global rag
    rag = await init()
    print("done!")
    yield


app = FastAPI(
    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)


# Data models
class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"
    only_need_context: bool = False


class InsertRequest(BaseModel):
    text: str


class Response(BaseModel):
    status: str
    data: Optional[str] = None
    message: Optional[str] = None


# API routes
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
    try:
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None,
            lambda: rag.query(
                request.query,
                param=QueryParam(
                    mode=request.mode, only_need_context=request.only_need_context
                ),
            ),
        )
        return Response(status="success", data=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# insert by text
@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
    try:
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(request.text))
        return Response(status="success", message="Text inserted successfully")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# insert by file in payload
@app.post("/insert_file", response_model=Response)
async def insert_file(file: UploadFile = File(...)):
    try:
        file_content = await file.read()
        # Read file content
        try:
            content = file_content.decode("utf-8")
        except UnicodeDecodeError:
            # If UTF-8 decoding fails, try other encodings
            content = file_content.decode("gbk")
        # Insert file content
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(content))

        return Response(
            status="success",
            message=f"File content from {file.filename} inserted successfully",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# insert by local default file
@app.post("/insert_default_file", response_model=Response)
@app.get("/insert_default_file", response_model=Response)
async def insert_default_file():
    try:
        # Read file content from book.txt
        async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file:
            content = await file.read()
        print(f"read input file {INPUT_FILE} successfully")
        # Insert file content
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(content))

        return Response(
            status="success",
            message=f"File content from {INPUT_FILE} inserted successfully",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)

# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py

# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'

# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'

# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"

# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
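
The `/insert_file` handler in the removed demo decodes uploads as UTF-8 and falls back to GBK. That pattern can be pulled into a small helper, sketched below. The name `decode_with_fallback` and its `encodings` parameter are hypothetical and were not part of the deleted file; unlike the demo, this sketch raises a single `ValueError` when every candidate encoding fails.

```python
def decode_with_fallback(data: bytes, encodings=("utf-8", "gbk")) -> str:
    """Try each encoding in order and return the first successful decode.

    Hypothetical helper for illustration; the encoding list is an assumption
    that mirrors the UTF-8 then GBK order used by the demo handler.
    """
    last_error = None
    for encoding in encodings:
        try:
            return data.decode(encoding)
        except UnicodeDecodeError as exc:
            # Remember the failure and move on to the next candidate.
            last_error = exc
    raise ValueError(
        f"could not decode payload with any of {encodings}"
    ) from last_error


utf8_text = decode_with_fallback("hello".encode("utf-8"))
gbk_text = decode_with_fallback("你好".encode("gbk"))
print(utf8_text, gbk_text)
```

One caveat of ordered fallbacks: bytes that happen to be valid in an earlier encoding are accepted even if they were written in a later one, so the order of `encodings` matters.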
@@ -1,204 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from contextlib import asynccontextmanager
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from typing import Optional
import asyncio
import nest_asyncio
from lightrag.kg.shared_storage import initialize_pipeline_status

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

DEFAULT_RAG_DIR = "index_default"
app = FastAPI(title="LightRAG API", description="API for RAG operations")

# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
print(f"BASE_URL: {BASE_URL}")
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
print(f"API_KEY: {API_KEY}")

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


# LLM model function


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        model=LLM_MODEL,
        prompt=prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url=BASE_URL,
        api_key=API_KEY,
        **kwargs,
    )


# Embedding function


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts=texts,
        model=EMBEDDING_MODEL,
        base_url=BASE_URL,
        api_key=API_KEY,
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    print(f"{embedding_dim=}")
    return embedding_dim


# Initialize RAG instance
async def init():
    embedding_dimension = await get_embedding_dim()

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
            max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
            func=embedding_func,
        ),
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


@asynccontextmanager
async def lifespan(app: FastAPI):
    global rag
    rag = await init()
    print("done!")
    yield


app = FastAPI(
    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)

# Data models


class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"
    only_need_context: bool = False


class InsertRequest(BaseModel):
    text: str


class Response(BaseModel):
    status: str
    data: Optional[str] = None
    message: Optional[str] = None


# API routes


@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
    try:
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None,
            lambda: rag.query(
                request.query,
                param=QueryParam(
                    mode=request.mode, only_need_context=request.only_need_context
                ),
            ),
        )
        return Response(status="success", data=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
    try:
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(request.text))
        return Response(status="success", message="Text inserted successfully")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/insert_file", response_model=Response)
async def insert_file(file: UploadFile = File(...)):
    try:
        file_content = await file.read()
        # Read file content
        try:
            content = file_content.decode("utf-8")
        except UnicodeDecodeError:
            # If UTF-8 decoding fails, try other encodings
            content = file_content.decode("gbk")
        # Insert file content
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(content))

        return Response(
            status="success",
            message=f"File content from {file.filename} inserted successfully",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)

# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py

# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'

# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'

# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"

# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
@@ -1,267 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Query
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional, Any

import sys
import os


from pathlib import Path

import asyncio
import nest_asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status


print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))


# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

DEFAULT_RAG_DIR = "index_default"


# We use an OpenAI-compatible API to call the LLM on Oracle Cloud
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"

# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

os.environ["ORACLE_USER"] = ""
os.environ["ORACLE_PASSWORD"] = ""
os.environ["ORACLE_DSN"] = ""
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
os.environ["ORACLE_WORKSPACE"] = "company"


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        LLM_MODEL,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=APIKEY,
        base_url=BASE_URL,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts,
        model=EMBEDDING_MODEL,
        api_key=APIKEY,
        base_url=BASE_URL,
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


async def init():
    # Detect embedding dimension
    embedding_dimension = await get_embedding_dim()
    print(f"Detected embedding dimension: {embedding_dimension}")
    # Create Oracle DB connection
    # The `config` parameter is the connection configuration of Oracle DB
    # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
    # We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
    # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud

    # Initialize LightRAG
    # We use Oracle DB as the KV/vector/graph storage
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,
        chunk_token_size=512,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
            max_token_size=512,
            func=embedding_func,
        ),
        graph_storage="OracleGraphStorage",
        kv_storage="OracleKVStorage",
        vector_storage="OracleVectorDBStorage",
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


# Extract and Insert into LightRAG storage
# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
#   await rag.ainsert(f.read())

# # Perform search in different modes
# modes = ["naive", "local", "global", "hybrid"]
# for mode in modes:
#     print("="*20, mode, "="*20)
#     print(await rag.aquery("What is this document about?", param=QueryParam(mode=mode)))
#     print("-"*100, "\n")

# Data models


class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"
    only_need_context: bool = False
    only_need_prompt: bool = False


class DataRequest(BaseModel):
    limit: int = 100


class InsertRequest(BaseModel):
    text: str


class Response(BaseModel):
    status: str
    data: Optional[Any] = None
    message: Optional[str] = None


# API routes

rag = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global rag
    rag = await init()
    print("done!")
    yield


app = FastAPI(
    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)


@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
    # try:
    # loop = asyncio.get_event_loop()
    if request.mode == "naive":
        top_k = 3
    else:
        top_k = 60
    result = await rag.aquery(
        request.query,
        param=QueryParam(
            mode=request.mode,
            only_need_context=request.only_need_context,
            only_need_prompt=request.only_need_prompt,
            top_k=top_k,
        ),
    )
    return Response(status="success", data=result)
    # except Exception as e:
    #     raise HTTPException(status_code=500, detail=str(e))


@app.get("/data", response_model=Response)
async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
    if type == "nodes":
        result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
    elif type == "edges":
        result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
    elif type == "statistics":
        result = await rag.chunk_entity_relation_graph.get_statistics()
    return Response(status="success", data=result)
|
||||
|
||||
|
||||
@app.post("/insert", response_model=Response)
|
||||
async def insert_endpoint(request: InsertRequest):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lambda: rag.insert(request.text))
|
||||
return Response(status="success", message="Text inserted successfully")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/insert_file", response_model=Response)
|
||||
async def insert_file(file: UploadFile = File(...)):
|
||||
try:
|
||||
file_content = await file.read()
|
||||
# Read file content
|
||||
try:
|
||||
content = file_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# If UTF-8 decoding fails, try other encodings
|
||||
content = file_content.decode("gbk")
|
||||
# Insert file content
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lambda: rag.insert(content))
|
||||
|
||||
return Response(
|
||||
status="success",
|
||||
message=f"File content from {file.filename} inserted successfully",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8020)
|
||||
|
||||
# Usage example
|
||||
# To run the server, use the following command in your terminal:
|
||||
# python lightrag_api_openai_compatible_demo.py
|
||||
|
||||
# Example requests:
|
||||
# 1. Query:
|
||||
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
|
||||
|
||||
# 2. Insert text:
|
||||
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
|
||||
|
||||
# 3. Insert file:
|
||||
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
|
||||
|
||||
|
||||
# 4. Health check:
|
||||
# curl -X GET "http://127.0.0.1:8020/health"
|
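The curl requests listed in the comments above can also be expressed in Python. The following is a minimal sketch that only builds the requests without sending them; the helper name `build_post` is hypothetical, and the host and port are taken from the `uvicorn.run` call in the demo.

```python
import json
import urllib.request

# Host and port as used by uvicorn.run(...) in the demo script.
BASE_URL = "http://127.0.0.1:8020"


def build_post(path: str, payload: dict) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for the demo API.

    `build_post` is an illustrative helper, not part of LightRAG.
    """
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        BASE_URL + path,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Same payloads as the curl examples for /query and /insert.
query_req = build_post("/query", {"query": "your query here", "mode": "hybrid"})
insert_req = build_post("/insert", {"text": "your text here"})

# Sending requires a running server, so it is left commented out:
# with urllib.request.urlopen(query_req) as resp:
#     print(json.load(resp))
```

Keeping request construction separate from sending makes the payload shape easy to inspect before a server is available.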
@@ -1,3 +1,7 @@
|
||||
##############################################
|
||||
# Gremlin storage implementation is deprecated
|
||||
##############################################
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
|
@@ -1,141 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
import numpy as np
|
||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
||||
|
||||
print(os.getcwd())
|
||||
script_directory = Path(__file__).resolve().parent.parent
|
||||
sys.path.append(os.path.abspath(script_directory))
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
# We use OpenAI compatible API to call LLM on Oracle Cloud
|
||||
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
|
||||
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
|
||||
APIKEY = "ocigenerativeai"
|
||||
CHATMODEL = "cohere.command-r-plus"
|
||||
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
|
||||
CHUNK_TOKEN_SIZE = 1024
|
||||
MAX_TOKENS = 4000
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
os.environ["ORACLE_USER"] = "username"
|
||||
os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
|
||||
os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
|
||||
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
|
||||
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
|
||||
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
|
||||
os.environ["ORACLE_WORKSPACE"] = "company"
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
CHATMODEL,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=APIKEY,
|
||||
base_url=BASE_URL,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await openai_embed(
|
||||
texts,
|
||||
model=EMBEDMODEL,
|
||||
api_key=APIKEY,
|
||||
base_url=BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
async def get_embedding_dim():
|
||||
test_text = ["This is a test sentence."]
|
||||
embedding = await embedding_func(test_text)
|
||||
embedding_dim = embedding.shape[1]
|
||||
return embedding_dim
|
||||
|
||||
|
||||
async def initialize_rag():
|
||||
# Detect embedding dimension
|
||||
embedding_dimension = await get_embedding_dim()
|
||||
print(f"Detected embedding dimension: {embedding_dimension}")
|
||||
|
||||
# Initialize LightRAG
|
||||
# We use Oracle DB as the KV/vector/graph storage
|
||||
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
|
||||
rag = LightRAG(
|
||||
# log_level="DEBUG",
|
||||
working_dir=WORKING_DIR,
|
||||
entity_extract_max_gleaning=1,
|
||||
enable_llm_cache=True,
|
||||
enable_llm_cache_for_entity_extract=True,
|
||||
embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
|
||||
chunk_token_size=CHUNK_TOKEN_SIZE,
|
||||
llm_model_max_token_size=MAX_TOKENS,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=embedding_dimension,
|
||||
max_token_size=500,
|
||||
func=embedding_func,
|
||||
),
|
||||
graph_storage="OracleGraphStorage",
|
||||
kv_storage="OracleKVStorage",
|
||||
vector_storage="OracleVectorDBStorage",
|
||||
addon_params={
|
||||
"example_number": 1,
|
||||
"language": "Simplified Chinese",
|
||||
"entity_types": ["organization", "person", "geo", "event"],
|
||||
"insert_batch_size": 2,
|
||||
},
|
||||
)
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
# Initialize RAG instance
|
||||
rag = await initialize_rag()
|
||||
|
||||
# Extract and Insert into LightRAG storage
|
||||
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
|
||||
all_text = f.read()
|
||||
texts = [x for x in all_text.split("\n") if x]
|
||||
|
||||
# New approach: use the pipeline
|
||||
await rag.apipeline_enqueue_documents(texts)
|
||||
await rag.apipeline_process_enqueue_documents()
|
||||
|
||||
# Old approach: use ainsert
|
||||
# await rag.ainsert(texts)
|
||||
|
||||
# Perform search in different modes
|
||||
modes = ["naive", "local", "global", "hybrid"]
|
||||
for mode in modes:
|
||||
print("=" * 20, mode, "=" * 20)
|
||||
print(
|
||||
await rag.aquery(
|
||||
"What are the top themes in this story?",
|
||||
param=QueryParam(mode=mode),
|
||||
)
|
||||
)
|
||||
print("-" * 100, "\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
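The removed demo probes the embedding dimension by embedding one test sentence and reading the width of the result. A minimal sketch of that idea follows, assuming a hypothetical stub (`stub_embedding_func`) in place of the real `openai_embed` call; the stub returns nested lists rather than a NumPy array, so the length of the first vector is used where the demo reads `embedding.shape[1]`.

```python
import asyncio


async def stub_embedding_func(texts: list[str]) -> list[list[float]]:
    """Hypothetical stand-in for embedding_func: one 4-dimensional vector per text."""
    return [[0.0, 0.0, 0.0, 0.0] for _ in texts]


async def get_embedding_dim(embed) -> int:
    """Embed a single probe sentence and report the vector length."""
    vectors = await embed(["This is a test sentence."])
    # The demo reads embedding.shape[1] from a NumPy array; with nested
    # lists the equivalent is the length of the first vector.
    return len(vectors[0])


detected_dim = asyncio.run(get_embedding_dim(stub_embedding_func))
print(f"Detected embedding dimension: {detected_dim}")
```

Probing once at startup avoids hard-coding `embedding_dim` when the embedding model is swapped.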
@@ -1,3 +1,7 @@
|
||||
###########################################
|
||||
# TiDB storage implementation is deprecated
|
||||
###########################################
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
|
@@ -291,11 +291,9 @@ LightRAG uses 4 types of storage for different purposes:
|
||||
|
||||
```
|
||||
JsonKVStorage JsonFile(default)
|
||||
MongoKVStorage MongoDB
|
||||
RedisKVStorage Redis
|
||||
TiDBKVStorage TiDB
|
||||
PGKVStorage Postgres
|
||||
OracleKVStorage Oracle
|
||||
RedisKVStorage Redis
|
||||
MongoKVStorage MongoDB
|
||||
```
|
||||
|
||||
* GRAPH_STORAGE supported implementation names
|
||||
@@ -303,25 +301,19 @@ OracleKVStorage Oracle
|
||||
```
|
||||
NetworkXStorage NetworkX(default)
|
||||
Neo4JStorage Neo4J
|
||||
MongoGraphStorage MongoDB
|
||||
TiDBGraphStorage TiDB
|
||||
AGEStorage AGE
|
||||
GremlinStorage Gremlin
|
||||
PGGraphStorage Postgres
|
||||
OracleGraphStorage Oracle
|
||||
AGEStorage AGE
|
||||
```
|
||||
|
||||
* VECTOR_STORAGE supported implementation names
|
||||
|
||||
```
|
||||
NanoVectorDBStorage NanoVector(default)
|
||||
PGVectorStorage Postgres
|
||||
MilvusVectorDBStorage Milvus
|
||||
ChromaVectorDBStorage Chroma
|
||||
TiDBVectorDBStorage TiDB
|
||||
PGVectorStorage Postgres
|
||||
FaissVectorDBStorage Faiss
|
||||
QdrantVectorDBStorage Qdrant
|
||||
OracleVectorDBStorage Oracle
|
||||
MongoVectorDBStorage MongoDB
|
||||
```
|
||||
|
||||
|
@@ -302,11 +302,9 @@ Each storage type has several implementations:
|
||||
|
||||
```
|
||||
JsonKVStorage JsonFile(default)
|
||||
MongoKVStorage MongoDB
|
||||
RedisKVStorage Redis
|
||||
TiDBKVStorage TiDB
|
||||
PGKVStorage Postgres
|
||||
OracleKVStorage Oracle
|
||||
RedisKVStorage Redis
|
||||
MongoKVStorage MongoDB
|
||||
```
|
||||
|
||||
* GRAPH_STORAGE supported implementation names
|
||||
@@ -314,25 +312,19 @@ OracleKVStorage Oracle
|
||||
```
|
||||
NetworkXStorage NetworkX(default)
|
||||
Neo4JStorage Neo4J
|
||||
MongoGraphStorage MongoDB
|
||||
TiDBGraphStorage TiDB
|
||||
AGEStorage AGE
|
||||
GremlinStorage Gremlin
|
||||
PGGraphStorage Postgres
|
||||
OracleGraphStorage Oracle
|
||||
AGEStorage AGE
|
||||
```
|
||||
|
||||
* VECTOR_STORAGE supported implementation names
|
||||
|
||||
```
|
||||
NanoVectorDBStorage NanoVector(default)
|
||||
MilvusVectorDBStorage Milvus
|
||||
ChromaVectorDBStorage Chroma
|
||||
TiDBVectorDBStorage TiDB
|
||||
PGVectorStorage Postgres
|
||||
MilvusVectorDBStorage Milvus
|
||||
ChromaVectorDBStorage Chroma
|
||||
FaissVectorDBStorage Faiss
|
||||
QdrantVectorDBStorage Qdrant
|
||||
OracleVectorDBStorage Oracle
|
||||
MongoVectorDBStorage MongoDB
|
||||
```
|
||||
|
||||
|
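The implementation names listed in the README hunks above are selected through environment variables. A minimal sketch of that lookup follows, assuming the variable names and defaults that appear in `lightrag/api/config.py` in this commit; the helper `resolve_storages` and the override value are hypothetical.

```python
import os

# Variable names and defaults as they appear in lightrag/api/config.py in this commit.
DEFAULTS = {
    "LIGHTRAG_KV_STORAGE": "JsonKVStorage",
    "LIGHTRAG_VECTOR_STORAGE": "NanoVectorDBStorage",
    "LIGHTRAG_GRAPH_STORAGE": "NetworkXStorage",
    "LIGHTRAG_DOC_STATUS_STORAGE": "JsonDocStatusStorage",
}


def resolve_storages(env=None) -> dict:
    """Return the implementation name for each storage variable.

    `resolve_storages` is an illustrative helper, not a LightRAG function.
    Unset variables fall back to the defaults above.
    """
    env = os.environ if env is None else env
    return {key: env.get(key, default) for key, default in DEFAULTS.items()}


# Hypothetical override: switch only the graph backend.
example = resolve_storages({"LIGHTRAG_GRAPH_STORAGE": "Neo4JStorage"})
print(example)
```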
@@ -1 +1 @@
|
||||
__api_version__ = "1.2.8"
|
||||
__api_version__ = "0136"
|
||||
|
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .config import global_args
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# this allows a different .env file for each LightRAG instance
|
||||
@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):
|
||||
|
||||
class AuthHandler:
|
||||
def __init__(self):
|
||||
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
|
||||
self.algorithm = "HS256"
|
||||
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
|
||||
self.guest_expire_hours = int(os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2))
|
||||
|
||||
self.secret = global_args.token_secret
|
||||
self.algorithm = global_args.jwt_algorithm
|
||||
self.expire_hours = global_args.token_expire_hours
|
||||
self.guest_expire_hours = global_args.guest_token_expire_hours
|
||||
self.accounts = {}
|
||||
auth_accounts = os.getenv("AUTH_ACCOUNTS")
|
||||
auth_accounts = global_args.auth_accounts
|
||||
if auth_accounts:
|
||||
for account in auth_accounts.split(","):
|
||||
username, password = account.split(":", 1)
|
||||
|
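The `AuthHandler` hunk above splits `auth_accounts` on commas and then on the first colon. A minimal sketch of that parsing follows; the function name `parse_auth_accounts` and the step that stores each pair in a dict are assumptions, since the hunk ends before the assignment is shown.

```python
def parse_auth_accounts(auth_accounts: str) -> dict[str, str]:
    """Parse "user1:pass1,user2:pass2" into a username -> password mapping.

    Illustrative helper; storing the pair in a dict is an assumption
    because the diff hunk is cut off after the split.
    """
    accounts: dict[str, str] = {}
    if auth_accounts:
        for account in auth_accounts.split(","):
            # Split on the first colon only, so passwords may contain ":".
            username, password = account.split(":", 1)
            accounts[username] = password
    return accounts


accounts = parse_auth_accounts("admin:s3cret,guest:a:b")
print(accounts)
```

Note that an entry without a colon would raise `ValueError` at the unpacking step, so malformed values surface at startup rather than at login time.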
335
lightrag/api/config.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Configs for the LightRAG API.
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# this allows a different .env file for each LightRAG instance
|
||||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
|
||||
class OllamaServerInfos:
|
||||
# Constants for emulated Ollama model information
|
||||
LIGHTRAG_NAME = "lightrag"
|
||||
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
||||
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
||||
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
||||
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
||||
LIGHTRAG_DIGEST = "sha256:lightrag"
|
||||
|
||||
|
||||
ollama_server_infos = OllamaServerInfos()
|
||||
|
||||
|
||||
class DefaultRAGStorageConfig:
|
||||
KV_STORAGE = "JsonKVStorage"
|
||||
VECTOR_STORAGE = "NanoVectorDBStorage"
|
||||
GRAPH_STORAGE = "NetworkXStorage"
|
||||
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
||||
|
||||
|
||||
def get_default_host(binding_type: str) -> str:
|
||||
default_hosts = {
|
||||
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
||||
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
|
||||
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
|
||||
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
|
||||
}
|
||||
return default_hosts.get(
|
||||
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
|
||||
) # fallback to ollama if unknown
|
||||
|
||||
|
||||
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
|
||||
"""
|
||||
Get value from environment variable with type conversion
|
||||
|
||||
Args:
|
||||
env_key (str): Environment variable key
|
||||
default (any): Default value if env variable is not set
|
||||
value_type (type): Type to convert the value to
|
||||
|
||||
Returns:
|
||||
any: Converted value from environment or default
|
||||
"""
|
||||
value = os.getenv(env_key)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
if value_type is bool:
|
||||
return value.lower() in ("true", "1", "yes", "t", "on")
|
||||
try:
|
||||
return value_type(value)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command line arguments with environment variable fallback
|
||||
|
||||
Args:
|
||||
None. Values are read from the command line, with environment variables as fallback defaults
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LightRAG FastAPI Server with separate working and input directories"
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default=get_env_value("HOST", "0.0.0.0"),
|
||||
help="Server host (default: from env or 0.0.0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=get_env_value("PORT", 9621, int),
|
||||
help="Server port (default: from env or 9621)",
|
||||
)
|
||||
|
||||
# Directory configuration
|
||||
parser.add_argument(
|
||||
"--working-dir",
|
||||
default=get_env_value("WORKING_DIR", "./rag_storage"),
|
||||
help="Working directory for RAG storage (default: from env or ./rag_storage)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-dir",
|
||||
default=get_env_value("INPUT_DIR", "./inputs"),
|
||||
help="Directory containing input documents (default: from env or ./inputs)",
|
||||
)
|
||||
|
||||
def timeout_type(value):
|
||||
if value is None:
|
||||
return 150
|
||||
if value is None or value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
default=get_env_value("TIMEOUT", None, timeout_type),
|
||||
type=timeout_type,
|
||||
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
||||
)
|
||||
|
||||
# RAG configuration
|
||||
parser.add_argument(
|
||||
"--max-async",
|
||||
type=int,
|
||||
default=get_env_value("MAX_ASYNC", 4, int),
|
||||
help="Maximum async operations (default: from env or 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=get_env_value("MAX_TOKENS", 32768, int),
|
||||
help="Maximum token size (default: from env or 32768)",
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default=get_env_value("LOG_LEVEL", "INFO"),
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Logging level (default: from env or INFO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=get_env_value("VERBOSE", False, bool),
|
||||
help="Enable verbose debug output (only valid for DEBUG log-level)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
type=str,
|
||||
default=get_env_value("LIGHTRAG_API_KEY", None),
|
||||
help="API key for authentication. This protects lightrag server against unauthorized access",
|
||||
)
|
||||
|
||||
# Optional https parameters
|
||||
parser.add_argument(
|
||||
"--ssl",
|
||||
action="store_true",
|
||||
default=get_env_value("SSL", False, bool),
|
||||
help="Enable HTTPS (default: from env or False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-certfile",
|
||||
default=get_env_value("SSL_CERTFILE", None),
|
||||
help="Path to SSL certificate file (required if --ssl is enabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-keyfile",
|
||||
default=get_env_value("SSL_KEYFILE", None),
|
||||
help="Path to SSL private key file (required if --ssl is enabled)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--history-turns",
|
||||
type=int,
|
||||
default=get_env_value("HISTORY_TURNS", 3, int),
|
||||
help="Number of conversation history turns to include (default: from env or 3)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=get_env_value("TOP_K", 60, int),
|
||||
help="Number of most similar results to return (default: from env or 60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cosine-threshold",
|
||||
type=float,
|
||||
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
||||
help="Cosine similarity threshold (default: from env or 0.2)",
|
||||
)
|
||||
|
||||
# Ollama model name
|
||||
parser.add_argument(
|
||||
"--simulated-model-name",
|
||||
type=str,
|
||||
default=get_env_value(
|
||||
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
||||
),
|
||||
help="Name of the simulated model (default: from env or ollama_server_infos.LIGHTRAG_MODEL)",
|
||||
)
|
||||
|
||||
# Namespace
|
||||
parser.add_argument(
|
||||
"--namespace-prefix",
|
||||
type=str,
|
||||
default=get_env_value("NAMESPACE_PREFIX", ""),
|
||||
help="Prefix of the namespace",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--auto-scan-at-startup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable automatic scanning when the program starts",
|
||||
)
|
||||
|
||||
# Server workers configuration
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=get_env_value("WORKERS", 1, int),
|
||||
help="Number of worker processes (default: from env or 1)",
|
||||
)
|
||||
|
||||
# LLM and embedding bindings
|
||||
parser.add_argument(
|
||||
"--llm-binding",
|
||||
type=str,
|
||||
default=get_env_value("LLM_BINDING", "ollama"),
|
||||
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
|
||||
help="LLM binding type (default: from env or ollama)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-binding",
|
||||
type=str,
|
||||
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
||||
choices=["lollms", "ollama", "openai", "azure_openai"],
|
||||
help="Embedding binding type (default: from env or ollama)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# convert relative path to absolute path
|
||||
args.working_dir = os.path.abspath(args.working_dir)
|
||||
args.input_dir = os.path.abspath(args.input_dir)
|
||||
|
||||
# Inject storage configuration from environment variables
|
||||
args.kv_storage = get_env_value(
|
||||
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
||||
)
|
||||
args.doc_status_storage = get_env_value(
|
||||
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
||||
)
|
||||
args.graph_storage = get_env_value(
|
||||
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
||||
)
|
||||
args.vector_storage = get_env_value(
|
||||
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
||||
)
|
||||
|
||||
# Get MAX_PARALLEL_INSERT from environment
|
||||
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
||||
|
||||
# Handle openai-ollama special case
|
||||
if args.llm_binding == "openai-ollama":
|
||||
args.llm_binding = "openai"
|
||||
args.embedding_binding = "ollama"
|
||||
|
||||
args.llm_binding_host = get_env_value(
|
||||
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
|
||||
)
|
||||
args.embedding_binding_host = get_env_value(
|
||||
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
|
||||
)
|
||||
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
|
||||
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
|
||||
|
||||
# Inject model configuration
|
||||
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
||||
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
||||
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
||||
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
|
||||
|
||||
# Inject chunk configuration
|
||||
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
||||
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
||||
|
||||
# Inject LLM cache configuration
|
||||
args.enable_llm_cache_for_extract = get_env_value(
|
||||
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
|
||||
)
|
||||
|
||||
# Inject LLM temperature configuration
|
||||
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
|
||||
|
||||
# Select Document loading tool (DOCLING, DEFAULT)
|
||||
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
||||
|
||||
# Add environment variables that were previously read directly
|
||||
args.cors_origins = get_env_value("CORS_ORIGINS", "*")
|
||||
args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
|
||||
args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
|
||||
|
||||
# For JWT Auth
|
||||
args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
|
||||
args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
|
||||
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
|
||||
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
|
||||
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
|
||||
|
||||
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def update_uvicorn_mode_config():
|
||||
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
||||
if global_args.workers > 1:
|
||||
original_workers = global_args.workers
|
||||
global_args.workers = 1
|
||||
# Log warning directly here
|
||||
logging.warning(
|
||||
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
|
||||
)
|
||||
|
||||
|
||||
global_args = parse_args()
|
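Most defaults in the new `config.py` go through `get_env_value`. The sketch below restates that helper so it runs on its own (annotated with `typing.Any` here instead of the builtin `any`) and exercises its fallback rules; the `DEMO_*` variable names are hypothetical and exist only for this example.

```python
import os
from typing import Any


def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
    """Read an environment variable and convert it, as in config.py above."""
    value = os.getenv(env_key)
    if value is None:
        return default
    if value_type is bool:
        # Booleans are matched against a small set of truthy strings.
        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except ValueError:
        # Unparsable values fall back to the default without raising.
        return default


# Hypothetical variables used only for this demonstration.
os.environ["DEMO_PORT"] = "9700"
os.environ["DEMO_VERBOSE"] = "Yes"
os.environ["DEMO_TOP_K"] = "many"
os.environ.pop("DEMO_UNSET", None)

port = get_env_value("DEMO_PORT", 9621, int)          # parsed from the environment
verbose = get_env_value("DEMO_VERBOSE", False, bool)  # "yes" counts as true
top_k = get_env_value("DEMO_TOP_K", 60, int)          # unparsable, default is used
log_level = get_env_value("DEMO_UNSET", "INFO")       # not set, default is used
print(port, verbose, top_k, log_level)
```

One consequence of this design is that a mistyped numeric value is ignored in favor of the default, which keeps startup robust but can hide configuration mistakes.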
@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
|
||||
from dotenv import load_dotenv
|
||||
from lightrag.api.utils_api import (
|
||||
get_combined_auth_dependency,
|
||||
parse_args,
|
||||
get_default_host,
|
||||
display_splash_screen,
|
||||
check_env_file,
|
||||
)
|
||||
from .config import (
|
||||
global_args,
|
||||
update_uvicorn_mode_config,
|
||||
get_default_host,
|
||||
)
|
||||
import sys
|
||||
from lightrag import LightRAG, __version__ as core_version
|
||||
from lightrag.api import __api_version__
|
||||
@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
|
||||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
|
||||
webui_title = os.getenv("WEBUI_TITLE")
|
||||
webui_description = os.getenv("WEBUI_DESCRIPTION")
|
||||
|
||||
# Initialize config parser
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini")
|
||||
@@ -164,10 +171,10 @@ def create_app(args):
|
||||
app = FastAPI(**app_kwargs)
|
||||
|
||||
def get_cors_origins():
|
||||
"""Get allowed origins from environment variable
|
||||
"""Get allowed origins from global_args
|
||||
Returns a list of allowed origins, defaults to ["*"] if not set
|
||||
"""
|
||||
origins_str = os.getenv("CORS_ORIGINS", "*")
|
||||
origins_str = global_args.cors_origins
|
||||
if origins_str == "*":
|
||||
return ["*"]
|
||||
return [origin.strip() for origin in origins_str.split(",")]
|
||||
@@ -315,9 +322,10 @@ def create_app(args):
|
||||
"similarity_threshold": 0.95,
|
||||
"use_llm_check": False,
|
||||
},
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
# namespace_prefix=args.namespace_prefix,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
addon_params={"language": args.summary_language},
|
||||
)
|
||||
else: # azure_openai
|
||||
rag = LightRAG(
|
||||
@@ -345,9 +353,10 @@ def create_app(args):
|
||||
"similarity_threshold": 0.95,
|
||||
"use_llm_check": False,
|
||||
},
|
||||
namespace_prefix=args.namespace_prefix,
|
||||
# namespace_prefix=args.namespace_prefix,
|
||||
auto_manage_storages_states=False,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
addon_params={"language": args.summary_language},
|
||||
)
|
||||
|
||||
# Add routes
|
||||
@@ -381,6 +390,8 @@ def create_app(args):
|
||||
"message": "Authentication is disabled. Using guest access.",
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"webui_title": webui_title,
|
||||
"webui_description": webui_description,
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -388,6 +399,8 @@ def create_app(args):
|
||||
"auth_mode": "enabled",
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"webui_title": webui_title,
|
||||
"webui_description": webui_description,
|
||||
}
|
||||
|
||||
@app.post("/login")
|
||||
@@ -404,6 +417,8 @@ def create_app(args):
|
||||
"message": "Authentication is disabled. Using guest access.",
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"webui_title": webui_title,
|
||||
"webui_description": webui_description,
|
||||
}
|
||||
username = form_data.username
|
||||
if auth_handler.accounts.get(username) != form_data.password:
|
||||
@@ -421,6 +436,8 @@ def create_app(args):
|
||||
"auth_mode": "enabled",
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"webui_title": webui_title,
|
||||
"webui_description": webui_description,
|
||||
}
|
||||
|
||||
@app.get("/health", dependencies=[Depends(combined_auth)])
|
||||
@@ -454,10 +471,12 @@ def create_app(args):
|
||||
"vector_storage": args.vector_storage,
|
||||
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
||||
},
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"auth_mode": auth_mode,
|
||||
"pipeline_busy": pipeline_status.get("busy", False),
|
||||
"core_version": core_version,
|
||||
"api_version": __api_version__,
|
||||
"webui_title": webui_title,
|
||||
"webui_description": webui_description,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting health status: {str(e)}")
|
||||
@@ -490,7 +509,7 @@ def create_app(args):
|
||||
def get_application(args=None):
|
||||
"""Factory function for creating the FastAPI application"""
|
||||
if args is None:
|
||||
args = parse_args()
|
||||
args = global_args
|
||||
return create_app(args)
|
||||
|
||||
|
||||
@@ -611,30 +630,31 @@ def main():
|
||||
|
||||
# Configure logging before parsing args
|
||||
configure_logging()
|
||||
|
||||
args = parse_args(is_uvicorn_mode=True)
|
||||
display_splash_screen(args)
|
||||
update_uvicorn_mode_config()
|
||||
display_splash_screen(global_args)
|
||||
|
||||
# Create application instance directly instead of using factory function
|
||||
app = create_app(args)
|
||||
app = create_app(global_args)
|
||||
|
||||
# Start Uvicorn in single process mode
|
||||
uvicorn_config = {
|
||||
"app": app, # Pass application instance directly instead of string path
|
||||
"host": args.host,
|
||||
"port": args.port,
|
||||
"host": global_args.host,
|
||||
"port": global_args.port,
|
||||
"log_config": None, # Disable default config
|
||||
}
|
||||
|
||||
if args.ssl:
|
||||
if global_args.ssl:
|
||||
uvicorn_config.update(
|
||||
{
|
||||
"ssl_certfile": args.ssl_certfile,
|
||||
"ssl_keyfile": args.ssl_keyfile,
|
||||
"ssl_certfile": global_args.ssl_certfile,
|
||||
"ssl_keyfile": global_args.ssl_keyfile,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
|
||||
print(
|
||||
f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
|
||||
)
|
||||
uvicorn.run(**uvicorn_config)
|
||||
|
||||
|
||||
|
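The `get_cors_origins` change above reads the origin list from `global_args.cors_origins` rather than from the environment. A minimal sketch of the parsing rule follows; unlike the function in the diff, this version takes the string as a parameter so it can run without the server, and the example origins are hypothetical.

```python
def get_cors_origins(origins_str: str = "*") -> list[str]:
    """Return the allowed origins as a list, defaulting to ["*"].

    The function in the diff takes no arguments and reads
    global_args.cors_origins; the parameter here is an assumption
    made so the sketch is self-contained.
    """
    if origins_str == "*":
        return ["*"]
    return [origin.strip() for origin in origins_str.split(",")]


wildcard = get_cors_origins()
explicit = get_cors_origins("http://localhost:3000, https://example.com")
print(wildcard, explicit)
```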
@@ -10,16 +10,14 @@ import traceback
|
||||
import pipmaster as pm
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, Literal
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from lightrag import LightRAG
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from lightrag.api.utils_api import (
|
||||
get_combined_auth_dependency,
|
||||
global_args,
|
||||
)
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
from ..config import global_args
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/documents",
|
||||
@@ -30,7 +28,37 @@ router = APIRouter(
|
||||
temp_prefix = "__tmp__"
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
|
||||
"""Response model for document scanning operation
|
||||
|
||||
Attributes:
|
||||
status: Status of the scanning operation
|
||||
message: Optional message with additional details
|
||||
"""
|
||||
|
||||
status: Literal["scanning_started"] = Field(
|
||||
description="Status of the scanning operation"
|
||||
)
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Additional details about the scanning operation"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"status": "scanning_started",
|
||||
"message": "Scanning process has been initiated in the background",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class InsertTextRequest(BaseModel):
|
||||
"""Request model for inserting a single text document
|
||||
|
||||
Attributes:
|
||||
text: The text content to be inserted into the RAG system
|
||||
"""
|
||||
|
||||
text: str = Field(
|
||||
min_length=1,
|
||||
description="The text to insert",
|
||||
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
|
||||
def strip_after(cls, text: str) -> str:
|
||||
return text.strip()
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"text": "This is a sample text to be inserted into the RAG system."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class InsertTextsRequest(BaseModel):
|
||||
"""Request model for inserting multiple text documents
|
||||
|
||||
Attributes:
|
||||
texts: List of text contents to be inserted into the RAG system
|
||||
"""
|
||||
|
||||
texts: list[str] = Field(
|
||||
min_length=1,
|
||||
description="The texts to insert",
|
||||
@@ -53,20 +94,100 @@ class InsertTextsRequest(BaseModel):
|
||||
def strip_after(cls, texts: list[str]) -> list[str]:
|
||||
return [text.strip() for text in texts]
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"texts": [
|
||||
"This is the first text to be inserted.",
|
||||
"This is the second text to be inserted.",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class InsertResponse(BaseModel):
|
||||
status: str = Field(description="Status of the operation")
|
||||
"""Response model for document insertion operations
|
||||
|
||||
Attributes:
|
||||
status: Status of the operation (success, duplicated, partial_success, failure)
|
||||
message: Detailed message describing the operation result
|
||||
"""
|
||||
|
||||
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
|
||||
description="Status of the operation"
|
||||
)
|
||||
message: str = Field(description="Message describing the operation result")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"status": "success",
|
||||
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ClearDocumentsResponse(BaseModel):
|
||||
"""Response model for document clearing operation
|
||||
|
||||
Attributes:
|
||||
status: Status of the clear operation
|
||||
message: Detailed message describing the operation result
|
||||
"""
|
||||
|
||||
status: Literal["success", "partial_success", "busy", "fail"] = Field(
|
||||
description="Status of the clear operation"
|
||||
)
|
||||
message: str = Field(description="Message describing the operation result")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"status": "success",
|
||||
"message": "All documents cleared successfully. Deleted 15 files.",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ClearCacheRequest(BaseModel):
|
||||
"""Request model for clearing cache
|
||||
|
||||
Attributes:
|
||||
modes: Optional list of cache modes to clear
|
||||
"""
|
||||
|
||||
modes: Optional[
|
||||
List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
|
||||
] = Field(
|
||||
default=None,
|
||||
description="Modes of cache to clear. If None, clears all cache.",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {"example": {"modes": ["default", "naive"]}}
|
||||
|
||||
|
||||
class ClearCacheResponse(BaseModel):
|
||||
"""Response model for cache clearing operation
|
||||
|
||||
Attributes:
|
||||
status: Status of the clear operation
|
||||
message: Detailed message describing the operation result
|
||||
"""
|
||||
|
||||
status: Literal["success", "fail"] = Field(
|
||||
description="Status of the clear operation"
|
||||
)
|
||||
message: str = Field(description="Message describing the operation result")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"status": "success",
|
||||
"message": "Successfully cleared cache for modes: ['default', 'naive']",
|
||||
}
|
||||
}
|
||||
|
||||
class DocStatusResponse(BaseModel):
|
||||
@staticmethod
|
||||
def format_datetime(dt: Any) -> Optional[str]:
|
||||
if dt is None:
|
||||
return None
|
||||
if isinstance(dt, str):
|
||||
return dt
|
||||
return dt.isoformat()
|
||||
|
||||
"""Response model for document status
|
||||
|
||||
@@ -80,22 +201,95 @@ class DocStatusResponse(BaseModel):
|
||||
chunks_count: Number of chunks (optional)
|
||||
error: Error message if any (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
file_path: Path to the document file
|
||||
"""
|
||||
|
||||
id: str
|
||||
content_summary: str
|
||||
content_length: int
|
||||
status: DocStatus
|
||||
created_at: str
|
||||
updated_at: str
|
||||
chunks_count: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
file_path: str
|
||||
|
||||
class DocStatusResponse(BaseModel):
|
||||
@staticmethod
|
||||
def format_datetime(dt: Any) -> Optional[str]:
|
||||
if dt is None:
|
||||
return None
|
||||
if isinstance(dt, str):
|
||||
return dt
|
||||
return dt.isoformat()
|
||||
|
||||
id: str = Field(description="Document identifier")
|
||||
content_summary: str = Field(description="Summary of document content")
|
||||
content_length: int = Field(description="Length of document content in characters")
|
||||
status: DocStatus = Field(description="Current processing status")
|
||||
created_at: str = Field(description="Creation timestamp (ISO format string)")
|
||||
updated_at: str = Field(description="Last update timestamp (ISO format string)")
|
||||
chunks_count: Optional[int] = Field(
|
||||
default=None, description="Number of chunks the document was split into"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="Error message if processing failed"
|
||||
)
|
||||
metadata: Optional[dict[str, Any]] = Field(
|
||||
default=None, description="Additional metadata about the document"
|
||||
)
|
||||
file_path: str = Field(description="Path to the document file")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"id": "doc_123456",
|
||||
"content_summary": "Research paper on machine learning",
|
||||
"content_length": 15240,
|
||||
"status": "PROCESSED",
|
||||
"created_at": "2025-03-31T12:34:56",
|
||||
"updated_at": "2025-03-31T12:35:30",
|
||||
"chunks_count": 12,
|
||||
"error": None,
|
||||
"metadata": {"author": "John Doe", "year": 2025},
|
||||
"file_path": "research_paper.pdf",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DocsStatusesResponse(BaseModel):
|
||||
statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
|
||||
"""Response model for document statuses
|
||||
|
||||
Attributes:
|
||||
statuses: Dictionary mapping document status to lists of document status responses
|
||||
"""
|
||||
|
||||
statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
|
||||
default_factory=dict,
|
||||
description="Dictionary mapping document status to lists of document status responses",
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"statuses": {
|
||||
"PENDING": [
|
||||
{
|
||||
"id": "doc_123",
|
||||
"content_summary": "Pending document",
|
||||
"content_length": 5000,
|
||||
"status": "PENDING",
|
||||
"created_at": "2025-03-31T10:00:00",
|
||||
"updated_at": "2025-03-31T10:00:00",
|
||||
"file_path": "pending_doc.pdf",
|
||||
}
|
||||
],
|
||||
"PROCESSED": [
|
||||
{
|
||||
"id": "doc_456",
|
||||
"content_summary": "Processed document",
|
||||
"content_length": 8000,
|
||||
"status": "PROCESSED",
|
||||
"created_at": "2025-03-31T09:00:00",
|
||||
"updated_at": "2025-03-31T09:05:00",
|
||||
"chunks_count": 8,
|
||||
"file_path": "processed_doc.pdf",
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PipelineStatusResponse(BaseModel):
|
||||
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
||||
)
|
||||
return False
|
||||
case ".pdf":
|
||||
if global_args["main_args"].document_loading_engine == "DOCLING":
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
||||
for page in reader.pages:
|
||||
content += page.extract_text() + "\n"
|
||||
case ".docx":
|
||||
if global_args["main_args"].document_loading_engine == "DOCLING":
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
||||
[paragraph.text for paragraph in doc.paragraphs]
|
||||
)
|
||||
case ".pptx":
|
||||
if global_args["main_args"].document_loading_engine == "DOCLING":
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
||||
if hasattr(shape, "text"):
|
||||
content += shape.text + "\n"
|
||||
case ".xlsx":
|
||||
if global_args["main_args"].document_loading_engine == "DOCLING":
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
|
||||
await rag.apipeline_process_enqueue_documents()
|
||||
|
||||
|
||||
# TODO: deprecate after /insert_file is removed
|
||||
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
||||
"""Save the uploaded file to a temporary location
|
||||
|
||||
@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
||||
if not new_files:
|
||||
return
|
||||
|
||||
# Get MAX_PARALLEL_INSERT from global_args["main_args"]
|
||||
max_parallel = global_args["main_args"].max_parallel_insert
|
||||
# Get MAX_PARALLEL_INSERT from global_args
|
||||
max_parallel = global_args.max_parallel_insert
|
||||
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
|
||||
batch_size = 2 * max_parallel
|
||||
|
||||
@@ -509,7 +704,9 @@ def create_document_routes(
|
||||
# Create combined auth dependency for document routes
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
@router.post("/scan", dependencies=[Depends(combined_auth)])
|
||||
@router.post(
|
||||
"/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Trigger the scanning process for new documents.
|
||||
@@ -519,13 +716,18 @@ def create_document_routes(
|
||||
that fact.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the scanning status
|
||||
ScanResponse: A response object containing the scanning status
|
||||
"""
|
||||
# Start the scanning process in the background
|
||||
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
||||
return {"status": "scanning_started"}
|
||||
return ScanResponse(
|
||||
status="scanning_started",
|
||||
message="Scanning process has been initiated in the background",
|
||||
)
|
||||
|
||||
@router.post("/upload", dependencies=[Depends(combined_auth)])
|
||||
@router.post(
|
||||
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def upload_to_input_dir(
|
||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||
):
|
||||
@@ -645,6 +847,7 @@ def create_document_routes(
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# TODO: deprecated, use /upload instead
|
||||
@router.post(
|
||||
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
@@ -688,6 +891,7 @@ def create_document_routes(
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# TODO: deprecated, use /upload instead
|
||||
@router.post(
|
||||
"/file_batch",
|
||||
response_model=InsertResponse,
|
||||
@@ -752,32 +956,186 @@ def create_document_routes(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete(
|
||||
"", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
||||
"", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def clear_documents():
|
||||
"""
|
||||
Clear all documents from the RAG system.
|
||||
|
||||
This endpoint deletes all text chunks, entities vector database, and relationships
|
||||
vector database, effectively clearing all documents from the RAG system.
|
||||
This endpoint deletes all documents, entities, relationships, and files from the system.
|
||||
It uses the storage drop methods to properly clean up all data and removes all files
|
||||
from the input directory.
|
||||
|
||||
Returns:
|
||||
InsertResponse: A response object containing the status and message.
|
||||
ClearDocumentsResponse: A response object containing the status and message.
|
||||
- status="success": All documents and files were successfully cleared.
|
||||
- status="partial_success": The document clear job exited with some errors.
|
||||
- status="busy": Operation could not be completed because the pipeline is busy.
|
||||
- status="fail": All storage drop operations failed.
|
||||
- message: Detailed information about the operation results, including counts
|
||||
of deleted files and any errors encountered.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during the clearing process (500).
|
||||
HTTPException: Raised when a serious error occurs during the clearing process,
|
||||
with status code 500 and error details in the detail field.
|
||||
"""
|
||||
try:
|
||||
rag.text_chunks = []
|
||||
rag.entities_vdb = None
|
||||
rag.relationships_vdb = None
|
||||
return InsertResponse(
|
||||
status="success", message="All documents cleared successfully"
|
||||
from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_pipeline_status_lock,
|
||||
)
|
||||
|
||||
# Get pipeline status and lock
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
pipeline_status_lock = get_pipeline_status_lock()
|
||||
|
||||
# Check and set status with lock
|
||||
async with pipeline_status_lock:
|
||||
if pipeline_status.get("busy", False):
|
||||
return ClearDocumentsResponse(
|
||||
status="busy",
|
||||
message="Cannot clear documents while pipeline is busy",
|
||||
)
|
||||
# Set busy to true
|
||||
pipeline_status.update(
|
||||
{
|
||||
"busy": True,
|
||||
"job_name": "Clearing Documents",
|
||||
"job_start": datetime.now().isoformat(),
|
||||
"docs": 0,
|
||||
"batchs": 0,
|
||||
"cur_batch": 0,
|
||||
"request_pending": False, # Clear any previous request
|
||||
"latest_message": "Starting document clearing process",
|
||||
}
|
||||
)
|
||||
# Clear history_messages in place so it remains the same shared list object
|
||||
del pipeline_status["history_messages"][:]
|
||||
pipeline_status["history_messages"].append(
|
||||
"Starting document clearing process"
|
||||
)
|
||||
|
||||
try:
|
||||
# Use drop method to clear all data
|
||||
drop_tasks = []
|
||||
storages = [
|
||||
rag.text_chunks,
|
||||
rag.full_docs,
|
||||
rag.entities_vdb,
|
||||
rag.relationships_vdb,
|
||||
rag.chunks_vdb,
|
||||
rag.chunk_entity_relation_graph,
|
||||
rag.doc_status,
|
||||
]
|
||||
|
||||
# Log storage drop start
|
||||
if "history_messages" in pipeline_status:
|
||||
pipeline_status["history_messages"].append(
|
||||
"Starting to drop storage components"
|
||||
)
|
||||
|
||||
for storage in storages:
|
||||
if storage is not None:
|
||||
drop_tasks.append(storage.drop())
|
||||
|
||||
# Wait for all drop tasks to complete
|
||||
drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)
|
||||
|
||||
# Check for errors and log results
|
||||
errors = []
|
||||
storage_success_count = 0
|
||||
storage_error_count = 0
|
||||
|
||||
for i, result in enumerate(drop_results):
|
||||
storage_name = storages[i].__class__.__name__
|
||||
if isinstance(result, Exception):
|
||||
error_msg = f"Error dropping {storage_name}: {str(result)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
storage_error_count += 1
|
||||
else:
|
||||
logger.info(f"Successfully dropped {storage_name}")
|
||||
storage_success_count += 1
|
||||
|
||||
# Log storage drop results
|
||||
if "history_messages" in pipeline_status:
|
||||
if storage_error_count > 0:
|
||||
pipeline_status["history_messages"].append(
|
||||
f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
|
||||
)
|
||||
else:
|
||||
pipeline_status["history_messages"].append(
|
||||
f"Successfully dropped all {storage_success_count} storage components"
|
||||
)
|
||||
|
||||
# If all storage operations failed, return error status and don't proceed with file deletion
|
||||
if storage_success_count == 0 and storage_error_count > 0:
|
||||
error_message = "All storage drop operations failed. Aborting document clearing process."
|
||||
logger.error(error_message)
|
||||
if "history_messages" in pipeline_status:
|
||||
pipeline_status["history_messages"].append(error_message)
|
||||
return ClearDocumentsResponse(status="fail", message=error_message)
|
||||
|
||||
# Log file deletion start
|
||||
if "history_messages" in pipeline_status:
|
||||
pipeline_status["history_messages"].append(
|
||||
"Starting to delete files in input directory"
|
||||
)
|
||||
|
||||
# Delete all files in input_dir
|
||||
deleted_files_count = 0
|
||||
file_errors_count = 0
|
||||
|
||||
for file_path in doc_manager.input_dir.glob("**/*"):
|
||||
if file_path.is_file():
|
||||
try:
|
||||
file_path.unlink()
|
||||
deleted_files_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error DELETE /documents: {str(e)}")
|
||||
logger.error(f"Error deleting file {file_path}: {str(e)}")
|
||||
file_errors_count += 1
|
||||
|
||||
# Log file deletion results
|
||||
if "history_messages" in pipeline_status:
|
||||
if file_errors_count > 0:
|
||||
pipeline_status["history_messages"].append(
|
||||
f"Deleted {deleted_files_count} files with {file_errors_count} errors"
|
||||
)
|
||||
errors.append(f"Failed to delete {file_errors_count} files")
|
||||
else:
|
||||
pipeline_status["history_messages"].append(
|
||||
f"Successfully deleted {deleted_files_count} files"
|
||||
)
|
||||
|
||||
# Prepare final result message
|
||||
final_message = ""
|
||||
if errors:
|
||||
final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
|
||||
status = "partial_success"
|
||||
else:
|
||||
final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
|
||||
status = "success"
|
||||
|
||||
# Log final result
|
||||
if "history_messages" in pipeline_status:
|
||||
pipeline_status["history_messages"].append(final_message)
|
||||
|
||||
# Return response based on results
|
||||
return ClearDocumentsResponse(status=status, message=final_message)
|
||||
except Exception as e:
|
||||
error_msg = f"Error clearing documents: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.error(traceback.format_exc())
|
||||
if "history_messages" in pipeline_status:
|
||||
pipeline_status["history_messages"].append(error_msg)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
# Reset busy status after completion
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["busy"] = False
|
||||
completion_msg = "Document clearing process completed"
|
||||
pipeline_status["latest_message"] = completion_msg
|
||||
if "history_messages" in pipeline_status:
|
||||
pipeline_status["history_messages"].append(completion_msg)
|
||||
|
||||
@router.get(
|
||||
"/pipeline_status",
|
||||
@@ -850,7 +1208,9 @@ def create_document_routes(
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("", dependencies=[Depends(combined_auth)])
|
||||
@router.get(
|
||||
"", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
|
||||
)
|
||||
async def documents() -> DocsStatusesResponse:
|
||||
"""
|
||||
Get the status of all documents in the system.
|
||||
@@ -908,4 +1268,57 @@ def create_document_routes(
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post(
|
||||
"/clear_cache",
|
||||
response_model=ClearCacheResponse,
|
||||
dependencies=[Depends(combined_auth)],
|
||||
)
|
||||
async def clear_cache(request: ClearCacheRequest):
|
||||
"""
|
||||
Clear cache data from the LLM response cache storage.
|
||||
|
||||
This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
|
||||
Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
|
||||
- "default" represents extraction cache.
|
||||
- Other modes correspond to different query modes.
|
||||
|
||||
Args:
|
||||
request (ClearCacheRequest): The request body containing optional modes to clear.
|
||||
|
||||
Returns:
|
||||
ClearCacheResponse: A response object containing the status and message.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
|
||||
"""
|
||||
try:
|
||||
# Validate modes if provided
|
||||
valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
|
||||
if request.modes and not all(mode in valid_modes for mode in request.modes):
|
||||
invalid_modes = [
|
||||
mode for mode in request.modes if mode not in valid_modes
|
||||
]
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
|
||||
)
|
||||
|
||||
# Call the aclear_cache method
|
||||
await rag.aclear_cache(request.modes)
|
||||
|
||||
# Prepare success message
|
||||
if request.modes:
|
||||
message = f"Successfully cleared cache for modes: {request.modes}"
|
||||
else:
|
||||
message = "Successfully cleared all cache"
|
||||
|
||||
return ClearCacheResponse(status="success", message=message)
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing cache: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return router
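The reworked `clear_documents` endpoint above drops all storages concurrently with `asyncio.gather(..., return_exceptions=True)` and then maps the outcome to `success`, `partial_success`, or `fail`. A minimal sketch of that classification, assuming a hypothetical `FakeStorage` class and `drop_all` helper (illustrative names, not LightRAG APIs):

```python
import asyncio
from typing import List, Tuple


class FakeStorage:
    """Hypothetical stand-in for a storage backend exposing an async drop()."""

    def __init__(self, name: str, fail: bool = False) -> None:
        self.name = name
        self.fail = fail

    async def drop(self) -> None:
        if self.fail:
            raise RuntimeError(f"{self.name} is unavailable")


async def drop_all(storages: List[FakeStorage]) -> Tuple[str, List[str]]:
    """Drop every storage concurrently and classify the overall outcome."""
    # return_exceptions=True collects failures as values instead of raising.
    results = await asyncio.gather(
        *(s.drop() for s in storages), return_exceptions=True
    )
    # zip keeps each result paired with the storage that produced it.
    errors = [
        f"Error dropping {s.name}: {r}"
        for s, r in zip(storages, results)
        if isinstance(r, Exception)
    ]
    if errors and len(errors) == len(storages):
        return "fail", errors
    return ("partial_success" if errors else "success"), errors


status_ok, _ = asyncio.run(drop_all([FakeStorage("kv"), FakeStorage("graph")]))
status_mixed, mixed_errors = asyncio.run(
    drop_all([FakeStorage("kv"), FakeStorage("graph", fail=True)])
)
status_fail, _ = asyncio.run(drop_all([FakeStorage("kv", fail=True)]))
```

The sketch pairs results with storages through `zip` over the same list that produced the tasks. The diff instead skips `None` storages when building `drop_tasks` but later indexes `storages[i]`, so the names in its log messages could be misaligned if any storage were `None`.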
|
||||
|
@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from ..utils_api import get_combined_auth_dependency
|
||||
|
||||
@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
|
||||
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
||||
async def get_knowledge_graph(
|
||||
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
||||
label: str = Query(..., description="Label to get knowledge graph for"),
|
||||
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
|
||||
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
|
||||
):
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
The maximum number of nodes is limited by the environment variable MAX_GRAPH_NODES (default: 1000).
|
||||
1. Hops (path length) to the starting node take precedence
|
||||
2. Followed by the degree of the nodes
|
||||
|
||||
Args:
|
||||
label (str): Label to get knowledge graph for
|
||||
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
|
||||
inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
|
||||
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
|
||||
label (str): Label of the starting node
|
||||
max_depth (int, optional): Maximum depth of the subgraph. Defaults to 3.
|
||||
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Knowledge graph for label
|
||||
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
return await rag.get_knowledge_graph(
|
||||
node_label=label,
|
||||
max_depth=max_depth,
|
||||
inclusive=inclusive,
|
||||
min_degree=min_degree,
|
||||
max_nodes=max_nodes,
|
||||
)
|
||||
|
||||
return router
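The updated docstring says that when the node count must be reduced, hops to the starting node take precedence, followed by node degree. The following is a rough sketch of one way to implement that ordering with a breadth-first search. It is not the LightRAG storage implementation, and `truncated_subgraph` and the adjacency-dict input are assumptions made for illustration:

```python
from collections import deque
from typing import Dict, List, Set


def truncated_subgraph(
    adjacency: Dict[str, Set[str]],
    start: str,
    max_depth: int = 3,
    max_nodes: int = 1000,
) -> List[str]:
    """Return up to max_nodes node ids within max_depth hops of start.

    Nodes with fewer hops win; ties go to the higher-degree node,
    then to the lexicographically smaller id for determinism.
    """
    if start not in adjacency:
        return []
    hops = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if hops[node] == max_depth:
            continue  # do not expand beyond the depth limit
        for neighbor in adjacency.get(node, ()):
            if neighbor not in hops:
                hops[neighbor] = hops[node] + 1
                queue.append(neighbor)
    ranked = sorted(
        hops, key=lambda n: (hops[n], -len(adjacency.get(n, ())), n)
    )
    return ranked[:max_nodes]


graph = {
    "A": {"B", "C"},
    "B": {"A", "D", "E"},
    "C": {"A"},
    "D": {"B"},
    "E": {"B"},
}
closest_three = truncated_subgraph(graph, "A", max_depth=2, max_nodes=3)
one_hop = truncated_subgraph(graph, "A", max_depth=1)
```

Here `closest_three` keeps the start node and its two direct neighbors, and the two-hop nodes `D` and `E` are the first to be cut.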
|
||||
|
@@ -7,14 +7,9 @@ import os
|
||||
import sys
|
||||
import signal
|
||||
import pipmaster as pm
|
||||
from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file
|
||||
from lightrag.api.utils_api import display_splash_screen, check_env_file
|
||||
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# this allows a different .env file for each LightRAG instance
|
||||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
from .config import global_args
|
||||
|
||||
|
||||
def check_and_install_dependencies():
|
||||
@@ -59,20 +54,17 @@ def main():
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # kill command
|
||||
|
||||
# Parse all arguments using parse_args
|
||||
args = parse_args(is_uvicorn_mode=False)
|
||||
|
||||
# Display startup information
|
||||
display_splash_screen(args)
|
||||
display_splash_screen(global_args)
|
||||
|
||||
print("🚀 Starting LightRAG with Gunicorn")
|
||||
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
|
||||
print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
|
||||
print("🔍 Preloading app: Enabled")
|
||||
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
|
||||
print("\n\n" + "=" * 80)
|
||||
print("MAIN PROCESS INITIALIZATION")
|
||||
print(f"Process ID: {os.getpid()}")
|
||||
print(f"Workers setting: {args.workers}")
|
||||
print(f"Workers setting: {global_args.workers}")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
# Import Gunicorn's StandaloneApplication
|
||||
@@ -128,31 +120,43 @@ def main():
|
||||
|
||||
# Set configuration variables in gunicorn_config, prioritizing command line arguments
|
||||
gunicorn_config.workers = (
|
||||
args.workers if args.workers else int(os.getenv("WORKERS", 1))
|
||||
global_args.workers
|
||||
if global_args.workers
|
||||
else int(os.getenv("WORKERS", 1))
|
||||
)
|
||||
|
||||
# Bind configuration prioritizes command line arguments
|
||||
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
|
||||
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
|
||||
host = (
|
||||
global_args.host
|
||||
if global_args.host != "0.0.0.0"
|
||||
else os.getenv("HOST", "0.0.0.0")
|
||||
)
|
||||
port = (
|
||||
global_args.port
|
||||
if global_args.port != 9621
|
||||
else int(os.getenv("PORT", 9621))
|
||||
)
|
||||
gunicorn_config.bind = f"{host}:{port}"
|
||||
|
||||
# Log level configuration prioritizes command line arguments
|
||||
gunicorn_config.loglevel = (
|
||||
args.log_level.lower()
|
||||
if args.log_level
|
||||
global_args.log_level.lower()
|
||||
if global_args.log_level
|
||||
else os.getenv("LOG_LEVEL", "info")
|
||||
)
|
||||
|
||||
# Timeout configuration prioritizes command line arguments
|
||||
gunicorn_config.timeout = (
|
||||
args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2))
|
||||
global_args.timeout
|
||||
if global_args.timeout * 2
|
||||
else int(os.getenv("TIMEOUT", 150 * 2))
|
||||
)
|
||||
|
||||
# Keepalive configuration
|
||||
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
|
||||
|
||||
# SSL configuration prioritizes command line arguments
|
||||
if args.ssl or os.getenv("SSL", "").lower() in (
|
||||
if global_args.ssl or os.getenv("SSL", "").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
@@ -160,12 +164,14 @@ def main():
|
||||
"on",
|
||||
):
|
||||
gunicorn_config.certfile = (
|
||||
args.ssl_certfile
|
||||
if args.ssl_certfile
|
||||
global_args.ssl_certfile
|
||||
if global_args.ssl_certfile
|
||||
else os.getenv("SSL_CERTFILE")
|
||||
)
|
||||
gunicorn_config.keyfile = (
|
||||
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
|
||||
global_args.ssl_keyfile
|
||||
if global_args.ssl_keyfile
|
||||
else os.getenv("SSL_KEYFILE")
|
||||
)
|
||||
|
||||
# Set configuration options from the module
|
||||
@@ -190,13 +196,13 @@ def main():
|
||||
# Import the application
|
||||
from lightrag.api.lightrag_server import get_application
|
||||
|
||||
return get_application(args)
|
||||
return get_application(global_args)
|
||||
|
||||
# Create the application
|
||||
app = GunicornApp("")
|
||||
|
||||
# Force workers to be an integer and greater than 1 for multi-process mode
|
||||
workers_count = int(args.workers)
|
||||
workers_count = int(global_args.workers)
|
||||
if workers_count > 1:
|
||||
# Set a flag to indicate we're in the main process
|
||||
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
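The Gunicorn launcher above repeats one precedence rule for several settings: use the command-line value unless it equals the built-in default, in which case fall back to an environment variable. A minimal sketch of that rule as a single helper, assuming a hypothetical `resolve_setting` function (not part of this commit) and an injectable `env` mapping so it can be exercised without touching the real environment:

```python
import os
from typing import Callable, Mapping, Optional, TypeVar

T = TypeVar("T")


def resolve_setting(
    cli_value: Optional[T],
    default: T,
    env_name: str,
    cast: Callable[[str], T],
    env: Mapping[str, str] = os.environ,
) -> T:
    """Prefer a non-default CLI value, then the environment, then the default."""
    if cli_value is not None and cli_value != default:
        return cli_value
    raw = env.get(env_name)
    return cast(raw) if raw is not None else default


# The CLI value equals the default, so the environment variable is used.
port_from_env = resolve_setting(9621, 9621, "PORT", int, env={"PORT": "8080"})
# An explicit non-default CLI value wins over the environment.
port_from_cli = resolve_setting(9000, 9621, "PORT", int, env={"PORT": "8080"})
# Neither source overrides the default.
port_default = resolve_setting(9621, 9621, "PORT", int, env={})
```

One limitation of comparing against the default, shared by the code in the diff, is that a user who explicitly passes the default value on the command line cannot be told apart from one who passed nothing, so the environment variable still takes effect.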
|
||||
|
@@ -7,15 +7,13 @@ import argparse
|
||||
from typing import Optional, List, Tuple
|
||||
import sys
|
||||
from ascii_colors import ASCIIColors
|
||||
import logging
|
||||
from lightrag.api import __api_version__ as api_version
|
||||
from lightrag import __version__ as core_version
|
||||
from fastapi import HTTPException, Security, Request, status
|
||||
from dotenv import load_dotenv
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from .auth import auth_handler
|
||||
from ..prompt import PROMPTS
|
||||
from .config import ollama_server_infos, global_args
|
||||
|
||||
|
||||
def check_env_file():
|
||||
@@ -36,16 +34,8 @@ def check_env_file():
|
||||
return True
|
||||
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# this allows a different .env file for each LightRAG instance
|
||||
# the OS environment variables take precedence over the .env file
|
||||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
global_args = {"main_args": None}
|
||||
|
||||
# Get whitelist paths from environment variable, only once during initialization
|
||||
default_whitelist = "/health,/api/*"
|
||||
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
|
||||
# Get whitelist paths from global_args, only once during initialization
|
||||
whitelist_paths = global_args.whitelist_paths.split(",")
|
||||
|
||||
# Pre-compile path matching patterns
|
||||
whitelist_patterns: List[Tuple[str, bool]] = []
|
||||
@@ -63,19 +53,6 @@ for path in whitelist_paths:
|
||||
auth_configured = bool(auth_handler.accounts)
|
||||
|
||||
|
||||
class OllamaServerInfos:
|
||||
# Constants for emulated Ollama model information
|
||||
LIGHTRAG_NAME = "lightrag"
|
||||
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
||||
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
||||
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
||||
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
||||
LIGHTRAG_DIGEST = "sha256:lightrag"
|
||||
|
||||
|
||||
ollama_server_infos = OllamaServerInfos()
|
||||
|
||||
|
||||
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
"""
|
||||
Create a combined authentication dependency that implements authentication logic
|
||||
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
return combined_dependency
|
||||
|
||||
|
||||
class DefaultRAGStorageConfig:
|
||||
KV_STORAGE = "JsonKVStorage"
|
||||
VECTOR_STORAGE = "NanoVectorDBStorage"
|
||||
GRAPH_STORAGE = "NetworkXStorage"
|
||||
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
||||
|
||||
|
||||
def get_default_host(binding_type: str) -> str:
|
||||
default_hosts = {
|
||||
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
||||
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
|
||||
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
|
||||
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
|
||||
}
|
||||
return default_hosts.get(
|
||||
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
|
||||
) # fallback to ollama if unknown
|
||||
|
||||
|
||||
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
|
||||
"""
|
||||
Get value from environment variable with type conversion
|
||||
|
||||
Args:
|
||||
env_key (str): Environment variable key
|
||||
default (any): Default value if env variable is not set
|
||||
value_type (type): Type to convert the value to
|
||||
|
||||
Returns:
|
||||
any: Converted value from environment or default
|
||||
"""
|
||||
value = os.getenv(env_key)
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
if value_type is bool:
|
||||
return value.lower() in ("true", "1", "yes", "t", "on")
|
||||
try:
|
||||
return value_type(value)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
||||
"""
|
||||
Parse command line arguments with environment variable fallback
|
||||
|
||||
Args:
|
||||
is_uvicorn_mode: Whether running under uvicorn mode
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LightRAG FastAPI Server with separate working and input directories"
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default=get_env_value("HOST", "0.0.0.0"),
|
||||
help="Server host (default: from env or 0.0.0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=get_env_value("PORT", 9621, int),
|
||||
help="Server port (default: from env or 9621)",
|
||||
)
|
||||
|
||||
# Directory configuration
|
||||
parser.add_argument(
|
||||
"--working-dir",
|
||||
default=get_env_value("WORKING_DIR", "./rag_storage"),
|
||||
help="Working directory for RAG storage (default: from env or ./rag_storage)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-dir",
|
||||
default=get_env_value("INPUT_DIR", "./inputs"),
|
||||
help="Directory containing input documents (default: from env or ./inputs)",
|
||||
)
|
||||
|
||||
def timeout_type(value):
|
||||
if value is None:
|
||||
return 150
|
||||
if value is None or value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
default=get_env_value("TIMEOUT", None, timeout_type),
|
||||
type=timeout_type,
|
||||
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
||||
)
|
||||
|
||||
# RAG configuration
|
||||
parser.add_argument(
|
||||
"--max-async",
|
||||
type=int,
|
||||
default=get_env_value("MAX_ASYNC", 4, int),
|
||||
help="Maximum async operations (default: from env or 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=get_env_value("MAX_TOKENS", 32768, int),
|
||||
help="Maximum token size (default: from env or 32768)",
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default=get_env_value("LOG_LEVEL", "INFO"),
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Logging level (default: from env or INFO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=get_env_value("VERBOSE", False, bool),
|
||||
help="Enable verbose debug output(only valid for DEBUG log-level)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
type=str,
|
||||
default=get_env_value("LIGHTRAG_API_KEY", None),
|
||||
help="API key for authentication. This protects lightrag server against unauthorized access",
|
||||
)
|
||||
|
||||
# Optional https parameters
|
||||
parser.add_argument(
|
||||
"--ssl",
|
||||
action="store_true",
|
||||
default=get_env_value("SSL", False, bool),
|
||||
help="Enable HTTPS (default: from env or False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-certfile",
|
||||
default=get_env_value("SSL_CERTFILE", None),
|
||||
help="Path to SSL certificate file (required if --ssl is enabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-keyfile",
|
||||
default=get_env_value("SSL_KEYFILE", None),
|
||||
help="Path to SSL private key file (required if --ssl is enabled)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--history-turns",
|
||||
type=int,
|
||||
default=get_env_value("HISTORY_TURNS", 3, int),
|
||||
help="Number of conversation history turns to include (default: from env or 3)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=get_env_value("TOP_K", 60, int),
|
||||
help="Number of most similar results to return (default: from env or 60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cosine-threshold",
|
||||
type=float,
|
||||
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
||||
help="Cosine similarity threshold (default: from env or 0.2)",
|
||||
)
|
||||
|
||||
# Ollama model name
|
||||
parser.add_argument(
|
||||
"--simulated-model-name",
|
||||
type=str,
|
||||
default=get_env_value(
|
||||
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
||||
),
|
||||
help="Model name reported by the emulated Ollama API (default: from env or ollama_server_infos.LIGHTRAG_MODEL)",
|
||||
)
|
||||
|
||||
# Namespace
|
||||
parser.add_argument(
|
||||
"--namespace-prefix",
|
||||
type=str,
|
||||
default=get_env_value("NAMESPACE_PREFIX", ""),
|
||||
help="Prefix of the namespace",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--auto-scan-at-startup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable automatic scanning when the program starts",
|
||||
)
|
||||
|
||||
# Server workers configuration
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=get_env_value("WORKERS", 1, int),
|
||||
help="Number of worker processes (default: from env or 1)",
|
||||
)
|
||||
|
||||
# LLM and embedding bindings
|
||||
parser.add_argument(
|
||||
"--llm-binding",
|
||||
type=str,
|
||||
default=get_env_value("LLM_BINDING", "ollama"),
|
||||
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
|
||||
help="LLM binding type (default: from env or ollama)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-binding",
|
||||
type=str,
|
||||
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
||||
choices=["lollms", "ollama", "openai", "azure_openai"],
|
||||
help="Embedding binding type (default: from env or ollama)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
||||
if is_uvicorn_mode and args.workers > 1:
|
||||
original_workers = args.workers
|
||||
args.workers = 1
|
||||
# Log warning directly here
|
||||
logging.warning(
|
||||
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
|
||||
)
|
||||
|
||||
# convert relative path to absolute path
|
||||
args.working_dir = os.path.abspath(args.working_dir)
|
||||
args.input_dir = os.path.abspath(args.input_dir)
|
||||
|
||||
# Inject storage configuration from environment variables
|
||||
args.kv_storage = get_env_value(
|
||||
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
||||
)
|
||||
args.doc_status_storage = get_env_value(
|
||||
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
||||
)
|
||||
args.graph_storage = get_env_value(
|
||||
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
||||
)
|
||||
args.vector_storage = get_env_value(
|
||||
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
||||
)
|
||||
|
||||
# Get MAX_PARALLEL_INSERT from environment
|
||||
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
||||
|
||||
# Handle openai-ollama special case
|
||||
if args.llm_binding == "openai-ollama":
|
||||
args.llm_binding = "openai"
|
||||
args.embedding_binding = "ollama"
|
||||
|
||||
args.llm_binding_host = get_env_value(
|
||||
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
|
||||
)
|
||||
args.embedding_binding_host = get_env_value(
|
||||
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
|
||||
)
|
||||
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
|
||||
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
|
||||
|
||||
# Inject model configuration
|
||||
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
||||
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
||||
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
||||
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
|
||||
|
||||
# Inject chunk configuration
|
||||
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
||||
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
||||
|
||||
# Inject LLM cache configuration
|
||||
args.enable_llm_cache_for_extract = get_env_value(
|
||||
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
|
||||
)
|
||||
|
||||
# Inject LLM temperature configuration
|
||||
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
|
||||
|
||||
# Select Document loading tool (DOCLING, DEFAULT)
|
||||
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
||||
|
||||
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
||||
|
||||
global_args["main_args"] = args
|
||||
return args
|
||||
|
||||
|
||||
def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
"""
|
||||
Display a colorful splash screen showing LightRAG server configuration
|
||||
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.white(" ├─ Workers: ", end="")
|
||||
ASCIIColors.yellow(f"{args.workers}")
|
||||
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
||||
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
|
||||
ASCIIColors.yellow(f"{args.cors_origins}")
|
||||
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
||||
ASCIIColors.yellow(f"{args.ssl}")
|
||||
if args.ssl:
|
||||
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.verbose}")
|
||||
ASCIIColors.white(" ├─ History Turns: ", end="")
|
||||
ASCIIColors.yellow(f"{args.history_turns}")
|
||||
ASCIIColors.white(" └─ API Key: ", end="")
|
||||
ASCIIColors.white(" ├─ API Key: ", end="")
|
||||
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
||||
ASCIIColors.white(" └─ JWT Auth: ", end="")
|
||||
ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
|
||||
|
||||
# Directory Configuration
|
||||
ASCIIColors.magenta("\n📂 Directory Configuration:")
|
||||
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.embedding_dim}")
|
||||
|
||||
# RAG Configuration
|
||||
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
|
||||
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
||||
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
||||
ASCIIColors.yellow(f"{summary_language}")
|
||||
ASCIIColors.yellow(f"{args.summary_language}")
|
||||
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
||||
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
|
||||
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
protocol = "https" if args.ssl else "http"
|
||||
if args.host == "0.0.0.0":
|
||||
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
||||
ASCIIColors.white(" ├─ Local Access: ", end="")
|
||||
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
|
||||
ASCIIColors.white(" ├─ Remote Access: ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
|
||||
ASCIIColors.white(" ├─ API Documentation (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
|
||||
ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="")
|
||||
ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
|
||||
ASCIIColors.white(" └─ WebUI (local): ", end="")
|
||||
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
|
||||
|
||||
ASCIIColors.yellow("\n📝 Note:")
|
||||
ASCIIColors.white(""" Since the server is running on 0.0.0.0:
|
||||
ASCIIColors.magenta("\n📝 Note:")
|
||||
ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
|
||||
- Use 'localhost' or '127.0.0.1' for local access
|
||||
- Use your machine's IP address for remote access
|
||||
- To find your IP address:
|
||||
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
else:
|
||||
base_url = f"{protocol}://{args.host}:{args.port}"
|
||||
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
||||
ASCIIColors.white(" ├─ Base URL: ", end="")
|
||||
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
||||
ASCIIColors.yellow(f"{base_url}")
|
||||
ASCIIColors.white(" ├─ API Documentation: ", end="")
|
||||
ASCIIColors.yellow(f"{base_url}/docs")
|
||||
ASCIIColors.white(" └─ Alternative Documentation: ", end="")
|
||||
ASCIIColors.yellow(f"{base_url}/redoc")
|
||||
|
||||
# Usage Examples
|
||||
ASCIIColors.magenta("\n📚 Quick Start Guide:")
|
||||
ASCIIColors.cyan("""
|
||||
1. Access the Swagger UI:
|
||||
Open your browser and navigate to the API documentation URL above
|
||||
|
||||
2. API Authentication:""")
|
||||
if args.key:
|
||||
ASCIIColors.cyan(""" Add the following header to your requests:
|
||||
X-API-Key: <your-api-key>
|
||||
""")
|
||||
else:
|
||||
ASCIIColors.cyan(" No authentication required\n")
|
||||
|
||||
ASCIIColors.cyan(""" 3. Basic Operations:
|
||||
- POST /upload_document: Upload new documents to RAG
|
||||
- POST /query: Query your document collection
|
||||
|
||||
4. Monitor the server:
|
||||
- Check server logs for detailed operation information
|
||||
- Use healthcheck endpoint: GET /health
|
||||
""")
|
||||
|
||||
# Security Notice
|
||||
if args.key:
|
||||
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
||||
ASCIIColors.white(""" API Key authentication is enabled.
|
||||
Make sure to include the X-API-Key header in all your requests.
|
||||
""")
|
||||
if args.auth_accounts:
|
||||
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
||||
ASCIIColors.white(""" JWT authentication is enabled.
|
||||
Make sure to log in before making requests, and include the 'Authorization' header.
|
||||
""")
|
||||
|
||||
# Ensure splash output flush to system log
|
||||
sys.stdout.flush()
|
||||
|
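The `get_env_value` helper that this hunk removes from the file resolves each setting from the environment with a typed fallback. The following is a standalone sketch of that behaviour, copied from the lines shown in the diff; the environment values set here are made up for illustration and are not part of the commit.

```python
import os
from typing import Any


def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
    """Read an environment variable and convert it, falling back to `default`."""
    value = os.getenv(env_key)
    if value is None:
        return default
    if value_type is bool:
        # Booleans are matched against a fixed set of truthy strings
        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except ValueError:
        # Unparseable values silently fall back to the default
        return default


# Illustrative environment (assumed values, not taken from the commit)
os.environ["PORT"] = "8080"
os.environ["VERBOSE"] = "Yes"
os.environ["MAX_ASYNC"] = "not-a-number"
os.environ.pop("WORKERS", None)

port = get_env_value("PORT", 9621, int)          # converted from the env string
verbose = get_env_value("VERBOSE", False, bool)  # case-insensitive truthy match
max_async = get_env_value("MAX_ASYNC", 4, int)   # conversion fails, default used
workers = get_env_value("WORKERS", 1, int)       # variable unset, default used
```

One consequence worth noting: a malformed value such as `MAX_ASYNC=not-a-number` does not raise, so a typo in `.env` can go unnoticed unless the resolved configuration is printed at startup.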
1
lightrag/api/webui/assets/index-CD5HxTy1.css
generated
File diff suppressed because one or more lines are too long
1345
lightrag/api/webui/assets/index-Cma7xY0-.js
generated
Normal file
File diff suppressed because one or more lines are too long
1
lightrag/api/webui/assets/index-QU59h9JG.css
generated
Normal file
File diff suppressed because one or more lines are too long
1321
lightrag/api/webui/assets/index-raheqJeu.js
generated
File diff suppressed because one or more lines are too long
4
lightrag/api/webui/index.html
generated
@@ -8,8 +8,8 @@
|
||||
<link rel="icon" type="image/svg+xml" href="logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Lightrag</title>
|
||||
<script type="module" crossorigin src="/webui/assets/index-raheqJeu.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/webui/assets/index-CD5HxTy1.css">
|
||||
<script type="module" crossorigin src="/webui/assets/index-Cma7xY0-.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/webui/assets/index-QU59h9JG.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
131
lightrag/base.py
@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
|
||||
async def index_done_callback(self) -> None:
|
||||
"""Commit the storage operations after indexing"""
|
||||
|
||||
@abstractmethod
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all data from storage and clean up resources
|
||||
|
||||
This abstract method defines the contract for dropping all data from a storage implementation.
|
||||
Each storage type must implement this method to:
|
||||
1. Clear all data from memory and/or external storage
|
||||
2. Remove any associated storage files if applicable
|
||||
3. Reset the storage to its initial state
|
||||
4. Handle cleanup of any resources
|
||||
5. Notify other processes if necessary
|
||||
6. This action should persist the data to disk immediately.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message with the following format:
|
||||
{
|
||||
"status": str, # "success" or "error"
|
||||
"message": str # "data dropped" on success, error details on failure
|
||||
}
|
||||
|
||||
Implementation specific:
|
||||
- On success: return {"status": "success", "message": "data dropped"}
|
||||
- On failure: return {"status": "error", "message": "<error details>"}
|
||||
- If not supported: return {"status": "error", "message": "unsupported"}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""Insert or update vectors in the storage."""
|
||||
"""Insert or update vectors in the storage.
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name."""
|
||||
"""Delete a single entity by its name.
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity."""
|
||||
"""Delete relations for a given entity.
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseKVStorage(StorageNameSpace, ABC):
|
||||
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""Upsert data"""
|
||||
"""Upsert data
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
True: if the cache was dropped successfully
|
||||
False: if the cache drop failed, or the cache mode is not supported
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get an edge by its source and target node ids."""
|
||||
"""Get node by its label identifier, return only node properties"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
"""Get all edges connected to a node."""
|
||||
"""Get edge properties between two nodes"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""Delete a node from the graph."""
|
||||
"""Delete a node from the graph.
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should update the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 3
|
||||
self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
|
||||
) -> KnowledgeGraph:
|
||||
"""Retrieve a subgraph of the knowledge graph starting from a given node."""
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node; * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, defaults to 3
|
||||
max_nodes: Maximum number of nodes to return, defaults to 1000 (BFS if possible)
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
|
||||
|
||||
class DocStatus(str, Enum):
|
||||
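The `get_knowledge_graph` docstring in the hunk above describes a depth-limited, node-capped subgraph with an `is_truncated` flag. The following is a minimal sketch of that idea over a plain adjacency dict. The function name `bfs_subgraph`, its return shape, and the directed-adjacency representation are assumptions made for illustration; this is not the code any LightRAG storage backend uses.

```python
from collections import deque


def bfs_subgraph(adjacency, start, max_depth=3, max_nodes=1000):
    """Return (nodes, edges, is_truncated) for a breadth-first walk from `start`.

    `adjacency` is a hypothetical dict mapping a node to its outgoing neighbors.
    """
    visited = [start]
    seen = {start}
    is_truncated = False
    queue = deque([(start, 0)])
    while queue:
        node, depth = queue.popleft()
        if depth >= max_depth:
            # Nodes beyond max_depth are out of scope, not "truncated"
            continue
        for neighbor in adjacency.get(node, []):
            if neighbor in seen:
                continue
            if len(visited) >= max_nodes:
                # A reachable node within max_depth was left out by the cap
                is_truncated = True
                continue
            seen.add(neighbor)
            visited.append(neighbor)
            queue.append((neighbor, depth + 1))
    # Keep only edges whose endpoints both survived the cut
    edges = [(a, b) for a in visited for b in adjacency.get(a, []) if b in seen]
    return visited, edges, is_truncated


graph = {"A": ["B", "C"], "B": ["D"], "C": ["D"], "D": ["E"]}
capped = bfs_subgraph(graph, "A", max_depth=2, max_nodes=3)
full = bfs_subgraph(graph, "A", max_depth=2, max_nodes=10)
```

Breadth-first order means that when the cap is hit, the nodes kept are the ones closest to the starting node, which is presumably why the docstring says "BFS if possible".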
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
|
||||
) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all documents with a specific status"""
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Drop cache is not supported for Doc Status storage"""
|
||||
return False
|
||||
|
||||
|
||||
class StoragesStatus(str, Enum):
|
||||
"""Storages status"""
|
||||
|
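The new abstract `drop()` method in `StorageNameSpace` fixes a return contract: a dict with `status` and `message` keys. The following is a minimal sketch of two implementations that honour it. The classes `InMemoryKVStorage` and `ReadOnlyStorage` are hypothetical and exist only for this illustration; they do not inherit from the real base classes.

```python
import asyncio


class InMemoryKVStorage:
    """Hypothetical dict-backed storage following the drop() contract."""

    def __init__(self):
        self._data = {}

    async def upsert(self, data):
        self._data.update(data)

    async def drop(self):
        try:
            self._data.clear()
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            # Failures are reported in the result dict instead of being raised
            return {"status": "error", "message": str(e)}


class ReadOnlyStorage:
    """Hypothetical backend that cannot drop its data."""

    async def drop(self):
        return {"status": "error", "message": "unsupported"}


async def demo():
    store = InMemoryKVStorage()
    await store.upsert({"doc-1": {"content": "hello"}})
    dropped = await store.drop()
    unsupported = await ReadOnlyStorage().drop()
    return dropped, unsupported, len(store._data)


dropped, unsupported, remaining = asyncio.run(demo())
```

Returning a status dict rather than raising lets a caller drop several storages in sequence and report each outcome, at the cost of requiring the caller to check `status` explicitly.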
@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"KV_STORAGE": {
|
||||
"implementations": [
|
||||
"JsonKVStorage",
|
||||
"MongoKVStorage",
|
||||
"RedisKVStorage",
|
||||
"TiDBKVStorage",
|
||||
"PGKVStorage",
|
||||
"OracleKVStorage",
|
||||
"MongoKVStorage",
|
||||
# "TiDBKVStorage",
|
||||
],
|
||||
"required_methods": ["get_by_id", "upsert"],
|
||||
},
|
||||
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"implementations": [
|
||||
"NetworkXStorage",
|
||||
"Neo4JStorage",
|
||||
"MongoGraphStorage",
|
||||
"TiDBGraphStorage",
|
||||
"AGEStorage",
|
||||
"GremlinStorage",
|
||||
"PGGraphStorage",
|
||||
"OracleGraphStorage",
|
||||
# "AGEStorage",
|
||||
# "MongoGraphStorage",
|
||||
# "TiDBGraphStorage",
|
||||
# "GremlinStorage",
|
||||
],
|
||||
"required_methods": ["upsert_node", "upsert_edge"],
|
||||
},
|
||||
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"NanoVectorDBStorage",
|
||||
"MilvusVectorDBStorage",
|
||||
"ChromaVectorDBStorage",
|
||||
"TiDBVectorDBStorage",
|
||||
"PGVectorStorage",
|
||||
"FaissVectorDBStorage",
|
||||
"QdrantVectorDBStorage",
|
||||
"OracleVectorDBStorage",
|
||||
"MongoVectorDBStorage",
|
||||
# "TiDBVectorDBStorage",
|
||||
],
|
||||
"required_methods": ["query", "upsert"],
|
||||
},
|
||||
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
|
||||
"implementations": [
|
||||
"JsonDocStatusStorage",
|
||||
"PGDocStatusStorage",
|
||||
"PGDocStatusStorage",
|
||||
"MongoDocStatusStorage",
|
||||
],
|
||||
"required_methods": ["get_docs_by_status"],
|
||||
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
||||
"JsonKVStorage": [],
|
||||
"MongoKVStorage": [],
|
||||
"RedisKVStorage": ["REDIS_URI"],
|
||||
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
# "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"OracleKVStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
# Graph Storage Implementations
|
||||
"NetworkXStorage": [],
|
||||
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
||||
"MongoGraphStorage": [],
|
||||
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"AGEStorage": [
|
||||
"AGE_POSTGRES_DB",
|
||||
"AGE_POSTGRES_USER",
|
||||
"AGE_POSTGRES_PASSWORD",
|
||||
],
|
||||
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
|
||||
# "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
|
||||
"PGGraphStorage": [
|
||||
"POSTGRES_USER",
|
||||
"POSTGRES_PASSWORD",
|
||||
"POSTGRES_DATABASE",
|
||||
],
|
||||
"OracleGraphStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
# Vector Storage Implementations
|
||||
"NanoVectorDBStorage": [],
|
||||
"MilvusVectorDBStorage": [],
|
||||
"ChromaVectorDBStorage": [],
|
||||
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
# "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
||||
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
||||
"FaissVectorDBStorage": [],
|
||||
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
|
||||
"OracleVectorDBStorage": [
|
||||
"ORACLE_DSN",
|
||||
"ORACLE_USER",
|
||||
"ORACLE_PASSWORD",
|
||||
"ORACLE_CONFIG_DIR",
|
||||
],
|
||||
"MongoVectorDBStorage": [],
|
||||
# Document Status Storage Implementations
|
||||
"JsonDocStatusStorage": [],
|
||||
@@ -112,9 +90,6 @@ STORAGES = {
|
||||
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
|
||||
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
|
||||
"Neo4JStorage": ".kg.neo4j_impl",
|
||||
"OracleKVStorage": ".kg.oracle_impl",
|
||||
"OracleGraphStorage": ".kg.oracle_impl",
|
||||
"OracleVectorDBStorage": ".kg.oracle_impl",
|
||||
"MilvusVectorDBStorage": ".kg.milvus_impl",
|
||||
"MongoKVStorage": ".kg.mongo_impl",
|
||||
"MongoDocStatusStorage": ".kg.mongo_impl",
|
||||
@@ -122,14 +97,14 @@ STORAGES = {
|
||||
"MongoVectorDBStorage": ".kg.mongo_impl",
|
||||
"RedisKVStorage": ".kg.redis_impl",
|
||||
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
||||
"TiDBKVStorage": ".kg.tidb_impl",
|
||||
"TiDBVectorDBStorage": ".kg.tidb_impl",
|
||||
"TiDBGraphStorage": ".kg.tidb_impl",
|
||||
# "TiDBKVStorage": ".kg.tidb_impl",
|
||||
# "TiDBVectorDBStorage": ".kg.tidb_impl",
|
||||
# "TiDBGraphStorage": ".kg.tidb_impl",
|
||||
"PGKVStorage": ".kg.postgres_impl",
|
||||
"PGVectorStorage": ".kg.postgres_impl",
|
||||
"AGEStorage": ".kg.age_impl",
|
||||
"PGGraphStorage": ".kg.postgres_impl",
|
||||
"GremlinStorage": ".kg.gremlin_impl",
|
||||
# "GremlinStorage": ".kg.gremlin_impl",
|
||||
"PGDocStatusStorage": ".kg.postgres_impl",
|
||||
"FaissVectorDBStorage": ".kg.faiss_impl",
|
||||
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
||||
|
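`STORAGE_ENV_REQUIREMENTS` maps each storage class name to the environment variables it needs. The following is a sketch of how such a mapping can be checked before a backend is loaded. The helper `missing_env_vars` and the `fake_env` values are assumptions for illustration; the mapping entries are an excerpt of those shown in the diff.

```python
import os

# Excerpt of the mapping shown in the diff above
STORAGE_ENV_REQUIREMENTS = {
    "JsonKVStorage": [],
    "RedisKVStorage": ["REDIS_URI"],
    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
}


def missing_env_vars(storage_name, environ=None):
    """Return the required variables for `storage_name` that are not set.

    Hypothetical helper; `environ` defaults to the process environment.
    """
    environ = os.environ if environ is None else environ
    required = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    return [name for name in required if name not in environ]


# Assumed environment for the example: the Neo4j password is deliberately absent
fake_env = {"NEO4J_URI": "bolt://localhost:7687", "NEO4J_USERNAME": "neo4j"}
neo4j_missing = missing_env_vars("Neo4JStorage", fake_env)
json_missing = missing_env_vars("JsonKVStorage", fake_env)
```

Passing the environment as a parameter keeps the check testable without mutating `os.environ`.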
@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
import psycopg
|
||||
from psycopg.rows import namedtuple_row
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
import psycopg # type: ignore
|
||||
from psycopg.rows import namedtuple_row # type: ignore
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
|
||||
|
||||
|
||||
class AGEQueryException(Exception):
|
||||
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
|
||||
async def index_done_callback(self) -> None:
|
||||
# AGES handles persistence automatically
|
||||
pass
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all nodes and relationships in the graph.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
query = """
|
||||
MATCH (n)
|
||||
DETACH DELETE n
|
||||
"""
|
||||
await self._query(query)
|
||||
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
|
||||
return {"status": "success", "message": "graph data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping graph {self.graph_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
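The `drop()` method above reports failure through its return value instead of raising, so a caller has to inspect the `status` key. The sketch below shows one way a caller might drop several storages at once and collect the results. `InMemoryGraph` and `drop_all` are hypothetical names used only for illustration and are not part of the code in this diff.

```python
import asyncio


class InMemoryGraph:
    """Hypothetical stand-in for a graph backend; not a LightRAG class."""

    def __init__(self, fail: bool = False):
        self.nodes = {"a": {}, "b": {}}
        self.fail = fail

    async def drop(self) -> dict[str, str]:
        try:
            if self.fail:
                raise ConnectionError("backend unreachable")
            self.nodes.clear()
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            # Errors are reported in the return value, not raised to the caller.
            return {"status": "error", "message": str(e)}


async def drop_all(storages: dict) -> dict:
    """Drop every storage concurrently and collect the status dicts by name."""
    results = await asyncio.gather(*(s.drop() for s in storages.values()))
    return dict(zip(storages.keys(), results))


report = asyncio.run(
    drop_all({"ok": InMemoryGraph(), "bad": InMemoryGraph(fail=True)})
)
failed = [name for name, r in report.items() if r["status"] != "success"]
print(failed)
```

Because no exception propagates, a caller that ignores the returned dict will not notice a failed drop.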
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
import numpy as np
|
||||
@@ -10,8 +11,8 @@ import pipmaster as pm
|
||||
if not pm.is_installed("chromadb"):
|
||||
pm.install("chromadb")
|
||||
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.config import Settings
|
||||
from chromadb import HttpClient, PersistentClient # type: ignore
|
||||
from chromadb.config import Settings # type: ignore
|
||||
|
||||
|
||||
@final
|
||||
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will delete all documents from the ChromaDB collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
# Get all IDs in the collection
|
||||
result = self._collection.get(include=[])
|
||||
if result and result["ids"] and len(result["ids"]) > 0:
|
||||
# Delete all documents
|
||||
self._collection.delete(ids=result["ids"])
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -11,16 +11,20 @@ import pipmaster as pm
|
||||
from lightrag.utils import logger, compute_mdhash_id
|
||||
from lightrag.base import BaseVectorStorage
|
||||
|
||||
if not pm.is_installed("faiss"):
|
||||
pm.install("faiss")
|
||||
|
||||
import faiss # type: ignore
|
||||
from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
)
|
||||
|
||||
import faiss # type: ignore
|
||||
|
||||
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
|
||||
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
|
||||
|
||||
if not pm.is_installed(FAISS_PACKAGE):
|
||||
pm.install(FAISS_PACKAGE)
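The package selection above depends on a single environment variable. A small sketch of the same decision as a testable function follows. `select_faiss_package` is a hypothetical helper, and the `env` parameter exists only so the example does not depend on the real process environment.

```python
import os


def select_faiss_package(env=None) -> str:
    """Return the pip package name implied by FAISS_USE_GPU."""
    env = os.environ if env is None else env
    # Only the exact string "1" enables the GPU build; any other value means CPU.
    use_gpu = env.get("FAISS_USE_GPU", "0") == "1"
    return "faiss-gpu" if use_gpu else "faiss-cpu"


print(select_faiss_package({"FAISS_USE_GPU": "1"}))
```

Values such as `true` or `yes` fall through to the CPU package under this comparison.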
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
async def delete(self, ids: list[str]):
|
||||
"""
|
||||
Delete vectors for the provided custom IDs.
|
||||
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
||||
to_remove = []
|
||||
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
|
||||
await self.delete([entity_id])
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""
|
||||
Delete relations for a given entity by scanning metadata.
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
logger.debug(f"Searching relations for entity {entity_name}")
|
||||
relations = []
|
||||
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
results.append({**metadata, "id": metadata.get("__id__")})
|
||||
|
||||
return results
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the vector database storage file if it exists
|
||||
2. Reinitialize the vector database client
|
||||
3. Update flags to notify other processes
|
||||
4. Persist the changes to disk immediately
|
||||
|
||||
This method will remove all vectors from the Faiss index and delete the storage files.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# Reset the index
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
|
||||
# Remove storage files if they exist
|
||||
if os.path.exists(self._faiss_index_file):
|
||||
os.remove(self._faiss_index_file)
|
||||
if os.path.exists(self._meta_file):
|
||||
os.remove(self._meta_file)
|
||||
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
|
||||
# Notify other processes
|
||||
await set_all_update_flags(self.namespace)
|
||||
self.storage_updated.value = False
|
||||
|
||||
logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
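The Faiss `drop()` above resets the in-memory state and removes the backing files while holding the storage lock. The sketch below reproduces only that shape with plain files, so it runs without Faiss. `FileBackedStore` and its attributes are hypothetical, and an `asyncio.Lock` stands in for the shared storage lock.

```python
import asyncio
import os
import tempfile


class FileBackedStore:
    """Hypothetical file-backed store; mirrors the shape of drop(), not Faiss itself."""

    def __init__(self, index_file: str, meta_file: str):
        self._index_file = index_file
        self._meta_file = meta_file
        self._items: dict = {}
        self._lock = asyncio.Lock()  # stand-in for the shared storage lock

    def save(self) -> None:
        # Write placeholder content so both files exist on disk.
        for path in (self._index_file, self._meta_file):
            with open(path, "w", encoding="utf-8") as f:
                f.write(str(len(self._items)))

    async def drop(self) -> dict[str, str]:
        try:
            async with self._lock:
                self._items = {}  # reset in-memory state first
                for path in (self._index_file, self._meta_file):
                    if os.path.exists(path):  # tolerate files that were never written
                        os.remove(path)
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}


async def demo():
    with tempfile.TemporaryDirectory() as tmp:
        store = FileBackedStore(
            os.path.join(tmp, "index.bin"), os.path.join(tmp, "meta.json")
        )
        store._items["v1"] = [0.1, 0.2]
        store.save()
        first = await store.drop()
        second = await store.drop()  # a second drop finds no files and still succeeds
        files_left = any(
            os.path.exists(p) for p in (store._index_file, store._meta_file)
        )
        return first, second, files_left


first, second, files_left = asyncio.run(demo())
print(first, files_left)
```

Checking `os.path.exists` before each removal is what keeps a repeated drop from failing.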
|
||||
|
@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
|
||||
if not pm.is_installed("gremlinpython"):
|
||||
pm.install("gremlinpython")
|
||||
|
||||
from gremlin_python.driver import client, serializer
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||
from gremlin_python.driver.protocol import GremlinServerError
|
||||
from gremlin_python.driver import client, serializer # type: ignore
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
|
||||
from gremlin_python.driver.protocol import GremlinServerError # type: ignore
|
||||
|
||||
|
||||
@final
|
||||
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge deletion: {str(e)}")
|
||||
raise
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all nodes and relationships in the graph.
|
||||
|
||||
This function deletes all nodes with the specified graph name property,
|
||||
which automatically removes all associated edges.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.drop()
|
||||
"""
|
||||
await self._query(query)
|
||||
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
|
||||
return {"status": "success", "message": "graph data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping graph {self.graph_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
await clear_all_update_flags(self.namespace)
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
||||
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
async with self._storage_lock:
|
||||
return self._data.get(id)
|
||||
|
||||
async def delete(self, doc_ids: list[str]):
|
||||
async with self._storage_lock:
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await set_all_update_flags(self.namespace)
|
||||
await self.index_done_callback()
|
||||
async def delete(self, doc_ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
doc_ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
async with self._storage_lock:
|
||||
any_deleted = False
|
||||
for doc_id in doc_ids:
|
||||
result = self._data.pop(doc_id, None)
|
||||
if result is not None:
|
||||
any_deleted = True
|
||||
|
||||
if any_deleted:
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all document status data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Clear all document status data from memory
|
||||
2. Update flags to notify other processes
|
||||
3. Trigger index_done_callback to save the empty state
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
await self.index_done_callback()
|
||||
logger.info(f"Process {os.getpid()} drop {self.namespace}")
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
||||
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete specific records from storage by their IDs
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
ids (list[str]): List of document IDs to be deleted from storage
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
async with self._storage_lock:
|
||||
any_deleted = False
|
||||
for doc_id in ids:
|
||||
self._data.pop(doc_id, None)
|
||||
result = self._data.pop(doc_id, None)
|
||||
if result is not None:
|
||||
any_deleted = True
|
||||
|
||||
if any_deleted:
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Important notes for in-memory storage:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. update flags to notify other processes that data persistence is needed
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
True: if the cache was dropped successfully
|
||||
False: if the cache drop failed
|
||||
"""
|
||||
if not modes:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self.delete(modes)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all data from storage and clean up resources
|
||||
This action will persist the change to disk immediately.
|
||||
|
||||
This method will:
|
||||
1. Clear all data from memory
|
||||
2. Update flags to notify other processes
|
||||
3. Trigger index_done_callback to save the empty state
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
await set_all_update_flags(self.namespace)
|
||||
|
||||
await self.index_done_callback()
|
||||
logger.info(f"Process {os.getpid()} drop {self.namespace}")
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
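The revised `delete()` methods in the JSON storages above set the update flags only when at least one record was actually removed. A minimal sketch of that guard follows. `DirtyTrackingKV` is a hypothetical class, and its single `needs_persist` boolean stands in for the cross-process update flags.

```python
import asyncio


class DirtyTrackingKV:
    """Hypothetical in-memory KV store with a single 'needs persisting' flag."""

    def __init__(self, data: dict):
        self._data = dict(data)
        self.needs_persist = False  # stand-in for the cross-process update flags

    async def delete(self, ids: list) -> None:
        any_deleted = False
        for doc_id in ids:
            # pop() returns None for a missing key, so misses do not count.
            if self._data.pop(doc_id, None) is not None:
                any_deleted = True
        if any_deleted:
            self.needs_persist = True


async def demo():
    store = DirtyTrackingKV({"doc-1": {"status": "processed"}})
    await store.delete(["missing"])  # nothing removed, flag stays False
    after_miss = store.needs_persist
    await store.delete(["doc-1", "missing"])  # one real removal sets the flag
    return after_miss, store.needs_persist


after_miss, after_hit = asyncio.run(demo())
print(after_miss, after_hit)
```

One consequence of testing the popped value against `None` is that a key whose stored value is `None` would be removed without setting the flag.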
|
||||
|
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
|
||||
pm.install("pymilvus")
|
||||
|
||||
import configparser
|
||||
from pymilvus import MilvusClient
|
||||
from pymilvus import MilvusClient # type: ignore
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will delete all data from the Milvus collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
# Drop the collection and recreate it
|
||||
if self._client.has_collection(self.namespace):
|
||||
self._client.drop_collection(self.namespace)
|
||||
|
||||
# Recreate the collection
|
||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
||||
self._client,
|
||||
self.namespace,
|
||||
dimension=self.embedding_func.embedding_dim,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
|
||||
if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
from motor.motor_asyncio import (
|
||||
from motor.motor_asyncio import ( # type: ignore
|
||||
AsyncIOMotorClient,
|
||||
AsyncIOMotorDatabase,
|
||||
AsyncIOMotorCollection,
|
||||
)
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
from pymongo.operations import SearchIndexModel # type: ignore
|
||||
from pymongo.errors import PyMongoError # type: ignore
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
|
||||
# Mongo handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete documents with specified IDs
|
||||
|
||||
Args:
|
||||
ids: List of document IDs to be deleted
|
||||
"""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self._data.delete_many({"_id": {"$in": ids}})
|
||||
logger.info(
|
||||
f"Deleted {result.deleted_count} documents from {self.namespace}"
|
||||
)
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error deleting documents from {self.namespace}: {e}")
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
if not modes:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Build regex pattern to match documents with the specified modes
|
||||
pattern = f"^({'|'.join(modes)})_"
|
||||
result = await self._data.delete_many({"_id": {"$regex": pattern}})
|
||||
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
||||
return False
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
result = await self._data.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||
# Mongo handles persistence automatically
|
||||
pass
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
result = await self._data.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
|
||||
logger.debug(f"Successfully deleted edges: {edges}")
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
result = await self.collection.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from graph {self._collection_name}"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping graph {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all documents in the collection and recreating vector index.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
try:
|
||||
# Delete all documents
|
||||
result = await self._data.delete_many({})
|
||||
deleted_count = result.deleted_count
|
||||
|
||||
# Recreate vector index
|
||||
await self.create_vector_index_if_not_exists()
|
||||
|
||||
logger.info(
|
||||
f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"{deleted_count} documents dropped and vector index recreated",
|
||||
}
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
||||
collection_names = await db.list_collection_names()
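`MongoKVStorage.drop_cache_by_modes` above builds an anchored alternation pattern from the mode names and matches it against document IDs. The sketch below exercises the same pattern with Python's `re` module. `build_mode_pattern` is a hypothetical helper, the `<mode>_<hash>` key layout is assumed for illustration, and the `re.escape` step is an addition that the diff does not perform.

```python
import re


def build_mode_pattern(modes: list, escape: bool = True) -> str:
    """Build an anchored alternation such as ^(local|global)_ from mode names."""
    # Escaping is an assumption added here; the diff joins the names as-is.
    parts = [re.escape(m) for m in modes] if escape else list(modes)
    return f"^({'|'.join(parts)})_"


# Hypothetical cache keys; the "<mode>_<hash>" layout is assumed for illustration.
keys = ["local_ab12", "global_cd34", "mix_ef56", "localized_gh78"]
pattern = re.compile(build_mode_pattern(["local", "global"]))
matched = [k for k in keys if pattern.match(k)]
print(matched)  # ['local_ab12', 'global_cd34']
```

The trailing underscore in the pattern is what keeps `localized_gh78` from matching the `local` mode. Without escaping, a mode name that contains regex metacharacters would change the meaning of the pattern.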
|
||||
|
@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
return self._client
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
try:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(
|
||||
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
storage = getattr(client, "_NanoVectorDB__storage")
|
||||
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
client = await self._get_client()
|
||||
return client.get(ids)
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the vector database storage file if it exists
|
||||
2. Reinitialize the vector database client
|
||||
3. Update flags to notify other processes
|
||||
4. Persist the changes to disk immediately
|
||||
|
||||
This method is intended for use in scenarios where all data needs to be removed.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# delete _client_file_name
|
||||
if os.path.exists(self._client_file_name):
|
||||
os.remove(self._client_file_name)
|
||||
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
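The NanoVectorDB `drop()` above sets the update flag for every process and then clears its own, so that the process that performed the drop does not reload data it has just reset. The sketch below models that sequence with plain dictionaries. `SharedFlags` and the integer worker IDs are hypothetical, and the real shared-storage helpers are asynchronous and may be structured differently.

```python
class SharedFlags:
    """Hypothetical registry holding one 'storage updated' flag per worker."""

    def __init__(self):
        self._flags: dict = {}

    def register(self, namespace: str, worker_id: int) -> None:
        self._flags.setdefault(namespace, {})[worker_id] = False

    def set_all(self, namespace: str) -> None:
        # Tell every registered worker that the namespace changed.
        for worker_id in self._flags.get(namespace, {}):
            self._flags[namespace][worker_id] = True

    def clear(self, namespace: str, worker_id: int) -> None:
        self._flags[namespace][worker_id] = False

    def needs_reload(self, namespace: str, worker_id: int) -> bool:
        return self._flags[namespace][worker_id]


flags = SharedFlags()
flags.register("entities", worker_id=1)
flags.register("entities", worker_id=2)

# Worker 1 drops the storage: notify everyone, then clear its own flag.
flags.set_all("entities")
flags.clear("entities", worker_id=1)

print(flags.needs_reload("entities", 1), flags.needs_reload("entities", 2))
```

The order matters: clearing the own flag before `set_all` would leave it set again and trigger a needless reload in the process that did the drop.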
|
||||
|
@@ -1,9 +1,8 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final, Optional
|
||||
from typing import Any, final
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
||||
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
GraphDatabase,
|
||||
)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self._driver = None
|
||||
self._driver_lock = asyncio.Lock()
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
||||
USERNAME = os.environ.get(
|
||||
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
||||
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
),
|
||||
)
|
||||
DATABASE = os.environ.get(
|
||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
|
||||
)
|
||||
|
||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||
@@ -98,22 +101,16 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
||||
)
|
||||
|
||||
# Try to connect to the database
|
||||
with GraphDatabase.driver(
|
||||
URI,
|
||||
auth=(USERNAME, PASSWORD),
|
||||
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
||||
connection_timeout=CONNECTION_TIMEOUT,
|
||||
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
||||
) as _sync_driver:
|
||||
# Try to connect to the database and create it if it doesn't exist
|
||||
for database in (DATABASE, None):
|
||||
self._DATABASE = database
|
||||
connected = False
|
||||
|
||||
try:
|
||||
with _sync_driver.session(database=database) as session:
|
||||
async with self._driver.session(database=database) as session:
|
||||
try:
|
||||
session.run("MATCH (n) RETURN n LIMIT 0")
|
||||
result = await session.run("MATCH (n) RETURN n LIMIT 0")
|
||||
await result.consume() # Ensure result is consumed
|
||||
logger.info(f"Connected to {database} at {URI}")
|
||||
connected = True
|
||||
except neo4jExceptions.ServiceUnavailable as e:
|
||||
@@ -130,10 +127,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
||||
)
|
||||
try:
|
||||
with _sync_driver.session() as session:
|
||||
session.run(
|
||||
async with self._driver.session() as session:
|
||||
result = await session.run(
|
||||
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
||||
)
|
||||
await result.consume() # Ensure result is consumed
|
||||
logger.info(f"{database} at {URI} created".capitalize())
|
||||
connected = True
|
||||
except (
|
||||
@@ -143,9 +141,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
if (
|
||||
e.code
|
||||
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
||||
) or (
|
||||
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
||||
):
|
||||
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
|
||||
if database is not None:
|
||||
logger.warning(
|
||||
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
||||
@@ -155,14 +151,42 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
raise e
|
||||
|
||||
if connected:
|
||||
# Create index for base nodes on entity_id if it doesn't exist
|
||||
try:
|
||||
async with self._driver.session(database=database) as session:
|
||||
# Check if index exists first
|
||||
check_query = """
|
||||
CALL db.indexes() YIELD name, labelsOrTypes, properties
|
||||
WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
|
||||
RETURN count(*) > 0 AS exists
|
||||
"""
|
||||
try:
|
||||
check_result = await session.run(check_query)
|
||||
record = await check_result.single()
|
||||
await check_result.consume()
|
||||
|
||||
index_exists = record and record.get("exists", False)
|
||||
|
||||
if not index_exists:
|
||||
# Create index only if it doesn't exist
|
||||
result = await session.run(
|
||||
"CREATE INDEX FOR (n:base) ON (n.entity_id)"
|
||||
)
|
||||
await result.consume()
|
||||
logger.info(
|
||||
f"Created index for base nodes on entity_id in {database}"
|
||||
)
|
||||
except Exception:
|
||||
# Fallback if db.indexes() is not supported in this Neo4j version
|
||||
result = await session.run(
|
||||
"CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
|
||||
)
|
||||
await result.consume()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create index: {str(e)}")
|
||||
break
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
async def finalize(self):
|
||||
"""Close the Neo4j driver and release all resources"""
|
||||
if self._driver:
|
||||
await self._driver.close()
|
||||
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
"""Ensure driver is closed when context manager exits"""
|
||||
await self.close()
|
||||
await self.finalize()
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Neo4J handles persistence automatically
|
||||
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
raise
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
"""Get node by its label identifier.
|
||||
"""Get node by its label identifier, return only node properties
|
||||
|
||||
Args:
|
||||
node_id: The node label to look up
|
||||
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
|
||||
)
|
||||
# Return default edge properties when no edge found
|
||||
return {
|
||||
"weight": 0.0,
|
||||
"source_id": None,
|
||||
"description": None,
|
||||
"keywords": None,
|
||||
}
|
||||
# Return None when no edge found
|
||||
return None
|
||||
finally:
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"""
|
||||
properties = node_data
|
||||
entity_type = properties["entity_type"]
|
||||
entity_id = properties["entity_id"]
|
||||
if "entity_id" not in properties:
|
||||
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
||||
|
||||
@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
query = (
|
||||
"""
|
||||
MERGE (n:base {entity_id: $properties.entity_id})
|
||||
MERGE (n:base {entity_id: $entity_id})
|
||||
SET n += $properties
|
||||
SET n:`%s`
|
||||
"""
|
||||
% entity_type
|
||||
)
|
||||
result = await tx.run(query, properties=properties)
|
||||
result = await tx.run(
|
||||
query, entity_id=node_id, properties=properties
|
||||
)
|
||||
logger.debug(
|
||||
f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
|
||||
f"Upserted node with entity_id '{node_id}' and properties: {properties}"
|
||||
)
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
min_degree: int = 0,
|
||||
inclusive: bool = False,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node
|
||||
max_depth: Maximum depth of the subgraph
|
||||
min_degree: Minimum degree of nodes to include. Defaults to 0
|
||||
inclusive: Do an inclusive search if true
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maximum nodes to return by BFS, Defaults to 1000
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: Complete connected subgraph for specified node
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
) as session:
|
||||
try:
|
||||
if node_label == "*":
|
||||
# First check total node count to determine if graph is truncated
|
||||
count_query = "MATCH (n) RETURN count(n) as total"
|
||||
count_result = None
|
||||
try:
|
||||
count_result = await session.run(count_query)
|
||||
count_record = await count_result.single()
|
||||
|
||||
if count_record and count_record["total"] > max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
||||
)
|
||||
finally:
|
||||
if count_result:
|
||||
await count_result.consume()
|
||||
|
||||
# Run main query to get nodes with highest degree
|
||||
main_query = """
|
||||
MATCH (n)
|
||||
OPTIONAL MATCH (n)-[r]-()
|
||||
WITH n, COALESCE(count(r), 0) AS degree
|
||||
WHERE degree >= $min_degree
|
||||
ORDER BY degree DESC
|
||||
LIMIT $max_nodes
|
||||
WITH collect({node: n}) AS filtered_nodes
|
||||
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
RETURN filtered_nodes AS node_info,
|
||||
collect(DISTINCT r) AS relationships
|
||||
"""
|
||||
result_set = None
|
||||
try:
|
||||
result_set = await session.run(
|
||||
main_query,
|
||||
{"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
|
||||
{"max_nodes": max_nodes},
|
||||
)
|
||||
record = await result_set.single()
|
||||
finally:
|
||||
if result_set:
|
||||
await result_set.consume()
|
||||
|
||||
else:
|
||||
# Main query uses partial matching
|
||||
main_query = """
|
||||
# return await self._robust_fallback(node_label, max_depth, max_nodes)
|
||||
# First try without limit to check if we need to truncate
|
||||
full_query = """
|
||||
MATCH (start)
|
||||
WHERE
|
||||
CASE
|
||||
WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
|
||||
ELSE start.entity_id = $entity_id
|
||||
END
|
||||
WHERE start.entity_id = $entity_id
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {
|
||||
relationshipFilter: '',
|
||||
@@ -688,40 +721,79 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
bfs: true
|
||||
})
|
||||
YIELD nodes, relationships
|
||||
WITH start, nodes, relationships
|
||||
WITH nodes, relationships, size(nodes) AS total_nodes
|
||||
UNWIND nodes AS node
|
||||
OPTIONAL MATCH (node)-[r]-()
|
||||
WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
|
||||
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
|
||||
ORDER BY
|
||||
CASE
|
||||
WHEN node = start THEN 3
|
||||
WHEN EXISTS((start)--(node)) THEN 2
|
||||
ELSE 1
|
||||
END DESC,
|
||||
degree DESC
|
||||
LIMIT $max_nodes
|
||||
WITH collect({node: node}) AS filtered_nodes
|
||||
UNWIND filtered_nodes AS node_info
|
||||
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
||||
OPTIONAL MATCH (a)-[r]-(b)
|
||||
WHERE a IN kept_nodes AND b IN kept_nodes
|
||||
RETURN filtered_nodes AS node_info,
|
||||
collect(DISTINCT r) AS relationships
|
||||
WITH collect({node: node}) AS node_info, relationships, total_nodes
|
||||
RETURN node_info, relationships, total_nodes
|
||||
"""
|
||||
result_set = await session.run(
|
||||
main_query,
|
||||
|
||||
# Try to get full result
|
||||
full_result = None
|
||||
try:
|
||||
full_result = await session.run(
|
||||
full_query,
|
||||
{
|
||||
"max_nodes": MAX_GRAPH_NODES,
|
||||
"entity_id": node_label,
|
||||
"inclusive": inclusive,
|
||||
"max_depth": max_depth,
|
||||
"min_degree": min_degree,
|
||||
},
|
||||
)
|
||||
full_record = await full_result.single()
|
||||
|
||||
# If no record found, return empty KnowledgeGraph
|
||||
if not full_record:
|
||||
logger.debug(f"No nodes found for entity_id: {node_label}")
|
||||
return result
|
||||
|
||||
# If record found, check node count
|
||||
total_nodes = full_record["total_nodes"]
|
||||
|
||||
if total_nodes <= max_nodes:
|
||||
# If node count is within limit, use full result directly
|
||||
logger.debug(
|
||||
f"Using full result with {total_nodes} nodes (no truncation needed)"
|
||||
)
|
||||
record = full_record
|
||||
else:
|
||||
# If node count exceeds limit, set truncated flag and run limited query
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
|
||||
)
|
||||
|
||||
# Run limited query
|
||||
limited_query = """
|
||||
MATCH (start)
|
||||
WHERE start.entity_id = $entity_id
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {
|
||||
relationshipFilter: '',
|
||||
minLevel: 0,
|
||||
maxLevel: $max_depth,
|
||||
limit: $max_nodes,
|
||||
bfs: true
|
||||
})
|
||||
YIELD nodes, relationships
|
||||
UNWIND nodes AS node
|
||||
WITH collect({node: node}) AS node_info, relationships
|
||||
RETURN node_info, relationships
|
||||
"""
|
||||
result_set = None
|
||||
try:
|
||||
result_set = await session.run(
|
||||
limited_query,
|
||||
{
|
||||
"entity_id": node_label,
|
||||
"max_depth": max_depth,
|
||||
"max_nodes": max_nodes,
|
||||
},
|
||||
)
|
||||
record = await result_set.single()
|
||||
finally:
|
||||
if result_set:
|
||||
await result_set.consume()
|
||||
finally:
|
||||
if full_result:
|
||||
await full_result.consume()
|
||||
|
||||
if record:
|
||||
# Handle nodes (compatible with multi-label cases)
|
||||
@@ -756,10 +828,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
finally:
|
||||
await result_set.consume() # Ensure result set is consumed
|
||||
|
||||
except neo4jExceptions.ClientError as e:
|
||||
logger.warning(f"APOC plugin error: {str(e)}")
|
||||
@@ -767,110 +837,28 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.warning(
|
||||
"Neo4j: falling back to basic Cypher recursive search..."
|
||||
)
|
||||
if inclusive:
|
||||
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
||||
else:
|
||||
logger.warning(
|
||||
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
|
||||
)
|
||||
return await self._robust_fallback(
|
||||
node_label, max_depth, min_degree
|
||||
"Neo4j: APOC plugin error with wildcard query, returning empty result"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _robust_fallback(
|
||||
self, node_label: str, max_depth: int, min_degree: int = 0
|
||||
self, node_label: str, max_depth: int, max_nodes: int
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Fallback implementation when APOC plugin is not available or incompatible.
|
||||
This method implements the same functionality as get_knowledge_graph but uses
|
||||
only basic Cypher queries and recursive traversal instead of APOC procedures.
|
||||
only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
|
||||
"""
|
||||
from collections import deque
|
||||
|
||||
result = KnowledgeGraph()
|
||||
visited_nodes = set()
|
||||
visited_edges = set()
|
||||
|
||||
async def traverse(
|
||||
node: KnowledgeGraphNode,
|
||||
edge: Optional[KnowledgeGraphEdge],
|
||||
current_depth: int,
|
||||
):
|
||||
# Check traversal limits
|
||||
if current_depth > max_depth:
|
||||
logger.debug(f"Reached max depth: {max_depth}")
|
||||
return
|
||||
if len(visited_nodes) >= MAX_GRAPH_NODES:
|
||||
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
|
||||
return
|
||||
|
||||
# Check if node already visited
|
||||
if node.id in visited_nodes:
|
||||
return
|
||||
|
||||
# Get all edges and target nodes
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
|
||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
||||
RETURN r, b, edge_id, target_id
|
||||
"""
|
||||
results = await session.run(query, entity_id=node.id)
|
||||
|
||||
# Get all records and release database connection
|
||||
records = await results.fetch(
|
||||
1000
|
||||
) # Max neighbor nodes we can handle
|
||||
await results.consume() # Ensure results are consumed
|
||||
|
||||
# Nodes not connected to start node need to check degree
|
||||
if current_depth > 1 and len(records) < min_degree:
|
||||
return
|
||||
|
||||
# Add current node to result
|
||||
result.nodes.append(node)
|
||||
visited_nodes.add(node.id)
|
||||
|
||||
# Add edge to result if it exists and not already added
|
||||
if edge and edge.id not in visited_edges:
|
||||
result.edges.append(edge)
|
||||
visited_edges.add(edge.id)
|
||||
|
||||
# Prepare nodes and edges for recursive processing
|
||||
nodes_to_process = []
|
||||
for record in records:
|
||||
rel = record["r"]
|
||||
edge_id = str(record["edge_id"])
|
||||
if edge_id not in visited_edges:
|
||||
b_node = record["b"]
|
||||
target_id = b_node.get("entity_id")
|
||||
|
||||
if target_id: # Only process if target node has entity_id
|
||||
# Create KnowledgeGraphNode for target
|
||||
target_node = KnowledgeGraphNode(
|
||||
id=f"{target_id}",
|
||||
labels=list(f"{target_id}"),
|
||||
properties=dict(b_node.properties),
|
||||
)
|
||||
|
||||
# Create KnowledgeGraphEdge
|
||||
target_edge = KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{node.id}",
|
||||
target=f"{target_id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
|
||||
nodes_to_process.append((target_node, target_edge))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping edge {edge_id} due to missing labels on target node"
|
||||
)
|
||||
|
||||
# Process nodes after releasing database connection
|
||||
for target_node, target_edge in nodes_to_process:
|
||||
await traverse(target_node, target_edge, current_depth + 1)
|
||||
visited_edge_pairs = set() # Tracks processed edge pairs as sorted (source_id, target_id) tuples
|
||||
|
||||
# Get the starting node's data
|
||||
async with self._driver.session(
|
||||
@@ -889,15 +877,129 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
# Create initial KnowledgeGraphNode
|
||||
start_node = KnowledgeGraphNode(
|
||||
id=f"{node_record['n'].get('entity_id')}",
|
||||
labels=list(f"{node_record['n'].get('entity_id')}"),
|
||||
properties=dict(node_record["n"].properties),
|
||||
labels=[node_record["n"].get("entity_id")],
|
||||
properties=dict(node_record["n"]._properties),
|
||||
)
|
||||
finally:
|
||||
await node_result.consume() # Ensure results are consumed
|
||||
|
||||
# Start traversal with the initial node
|
||||
await traverse(start_node, None, 0)
|
||||
# Initialize queue for BFS with (node, edge, depth) tuples
|
||||
# edge is None for the starting node
|
||||
queue = deque([(start_node, None, 0)])
|
||||
|
||||
# True BFS implementation using a queue
|
||||
while queue and len(visited_nodes) < max_nodes:
|
||||
# Dequeue the next node to process
|
||||
current_node, current_edge, current_depth = queue.popleft()
|
||||
|
||||
# Skip if already visited or exceeds max depth
|
||||
if current_node.id in visited_nodes:
|
||||
continue
|
||||
|
||||
if current_depth > max_depth:
|
||||
logger.debug(
|
||||
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
|
||||
)
|
||||
continue
|
||||
|
||||
# Add current node to result
|
||||
result.nodes.append(current_node)
|
||||
visited_nodes.add(current_node.id)
|
||||
|
||||
# Add edge to result if it exists and not already added
|
||||
if current_edge and current_edge.id not in visited_edges:
|
||||
result.edges.append(current_edge)
|
||||
visited_edges.add(current_edge.id)
|
||||
|
||||
# Stop if we've reached the node limit
|
||||
if len(visited_nodes) >= max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
|
||||
)
|
||||
break
|
||||
|
||||
# Get all edges and target nodes for the current node (even at max_depth)
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
|
||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
||||
RETURN r, b, edge_id, target_id
|
||||
"""
|
||||
results = await session.run(query, entity_id=current_node.id)
|
||||
|
||||
# Get all records and release database connection
|
||||
records = await results.fetch(1000) # Max neighbor nodes we can handle
|
||||
await results.consume() # Ensure results are consumed
|
||||
|
||||
# Process all neighbors - capture all edges but only queue unvisited nodes
|
||||
for record in records:
|
||||
rel = record["r"]
|
||||
edge_id = str(record["edge_id"])
|
||||
|
||||
if edge_id not in visited_edges:
|
||||
b_node = record["b"]
|
||||
target_id = b_node.get("entity_id")
|
||||
|
||||
if target_id: # Only process if target node has entity_id
|
||||
# Create KnowledgeGraphNode for target
|
||||
target_node = KnowledgeGraphNode(
|
||||
id=f"{target_id}",
|
||||
labels=[target_id],
|
||||
properties=dict(b_node._properties),
|
||||
)
|
||||
|
||||
# Create KnowledgeGraphEdge
|
||||
target_edge = KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{current_node.id}",
|
||||
target=f"{target_id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
|
||||
# Sort source_id and target_id so that (A, B) and (B, A) are treated as the same edge
|
||||
sorted_pair = tuple(sorted([current_node.id, target_id]))
|
||||
|
||||
# Check whether the same edge already exists (edges are treated as undirected)
|
||||
if sorted_pair not in visited_edge_pairs:
|
||||
# Only add the edge if the target node is already in the result or will be added to it
|
||||
if target_id in visited_nodes or (
|
||||
target_id not in visited_nodes
|
||||
and current_depth < max_depth
|
||||
):
|
||||
result.edges.append(target_edge)
|
||||
visited_edges.add(edge_id)
|
||||
visited_edge_pairs.add(sorted_pair)
|
||||
|
||||
# Only add unvisited nodes to the queue for further expansion
|
||||
if target_id not in visited_nodes:
|
||||
# Only add to queue if we're not at max depth yet
|
||||
if current_depth < max_depth:
|
||||
# Add node to queue with incremented depth
|
||||
# Edge is already added to result, so we pass None as edge
|
||||
queue.append((target_node, None, current_depth + 1))
|
||||
else:
|
||||
# At max depth, we've already added the edge but we don't add the node
|
||||
# This prevents adding nodes beyond max_depth to the result
|
||||
logger.debug(
|
||||
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
|
||||
)
|
||||
else:
|
||||
# If target node already exists in result, we don't need to add it again
|
||||
logger.debug(
|
||||
f"Node {target_id} already visited, edge added but node not queued"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping edge {edge_id} due to missing entity_id on target node"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
return result
|
||||
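The queue-based traversal added in `_robust_fallback` can be illustrated without a database. The following is a minimal sketch, assuming an in-memory adjacency dict in place of the Neo4j session; `bfs_subgraph` and its arguments are hypothetical names and are not part of this commit. It also reports truncation only when an unvisited node is still queued, which is a slightly stricter check than the one in the diff.

```python
from collections import deque


def bfs_subgraph(adjacency, start, max_depth=3, max_nodes=1000):
    """Bounded breadth-first search over an undirected adjacency dict.

    Returns (nodes, edges, is_truncated). Hypothetical helper for illustration.
    """
    if start not in adjacency:
        return [], [], False

    visited = set()
    nodes = []
    queue = deque([(start, 0)])  # (node, depth) pairs

    while queue and len(nodes) < max_nodes:
        current, depth = queue.popleft()
        if current in visited:
            continue
        visited.add(current)
        nodes.append(current)

        # Only expand neighbors while we are still above max_depth
        if depth < max_depth:
            for neighbor in adjacency.get(current, ()):
                if neighbor not in visited:
                    queue.append((neighbor, depth + 1))

    # Truncated only if some queued node was never visited
    is_truncated = any(node not in visited for node, _ in queue)

    # Keep each undirected edge once, and only between returned nodes
    kept = set(nodes)
    edges = sorted(
        {
            tuple(sorted((a, b)))
            for a in nodes
            for b in adjacency.get(a, ())
            if b in kept
        }
    )
    return nodes, edges, is_truncated
```

Using `deque.popleft()` keeps each dequeue O(1), whereas `list.pop(0)` shifts the whole list on every call.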
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
# Method 2: Query compatible with older versions
|
||||
query = """
|
||||
MATCH (n)
|
||||
MATCH (n:base)
|
||||
WHERE n.entity_id IS NOT NULL
|
||||
RETURN DISTINCT n.entity_id AS label
|
||||
ORDER BY label
|
||||
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all data from storage and clean up resources
|
||||
|
||||
This method will delete all nodes and relationships in the Neo4j database.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
# Delete all nodes and relationships
|
||||
query = "MATCH (n) DETACH DELETE n"
|
||||
result = await session.run(query)
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
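The `drop()` methods in this commit report their outcome as a status dict instead of raising. A minimal sketch of that convention follows, assuming a caller-supplied cleanup callable; `drop_with_status` and `clear` are hypothetical names used only for illustration.

```python
def drop_with_status(clear):
    """Run a cleanup callable and report the outcome as a status dict."""
    try:
        clear()
        return {"status": "success", "message": "data dropped"}
    except Exception as e:
        # Errors are reported to the caller rather than propagated
        return {"status": "error", "message": str(e)}
```

Callers are then expected to check `result["status"]` rather than rely on exceptions.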
|
@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
)
|
||||
nx.write_graphml(graph, file_name)
|
||||
|
||||
# TODO:deprecated, remove later
|
||||
@staticmethod
|
||||
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
||||
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
return None
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
if graph.has_node(node_id):
|
||||
graph.remove_node(node_id)
|
||||
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||
|
||||
# TODO: NOT USED
|
||||
async def embed_nodes(
|
||||
self, algorithm: str
|
||||
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
async def remove_nodes(self, nodes: list[str]):
|
||||
"""Delete multiple nodes
|
||||
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
nodes: List of node IDs to be deleted
|
||||
"""
|
||||
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
async def remove_edges(self, edges: list[tuple[str, str]]):
|
||||
"""Delete multiple edges
|
||||
|
||||
Important notes:
|
||||
1. Changes will be persisted to disk during the next index_done_callback
|
||||
2. Only one process should be updating the storage at a time before index_done_callback,
|
||||
KG-storage-log should be used to avoid data corruption
|
||||
|
||||
Args:
|
||||
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
||||
"""
|
||||
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
min_degree: int = 0,
|
||||
inclusive: bool = False,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. min_degree does not affect nodes directly connected to the matching nodes
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node
|
||||
max_depth: Maximum depth of the subgraph
|
||||
min_degree: Minimum degree of nodes to include. Defaults to 0
|
||||
inclusive: Do an inclusive search if true
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maximum nodes to return by BFS, Defaults to 1000
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
||||
graph = await self._get_graph()
|
||||
|
||||
# Initialize sets for start nodes and direct connected nodes
|
||||
start_nodes = set()
|
||||
direct_connected_nodes = set()
|
||||
result = KnowledgeGraph()
|
||||
|
||||
# Handle special case for "*" label
|
||||
if node_label == "*":
|
||||
# For "*", return the entire graph including all nodes and edges
|
||||
subgraph = (
|
||||
graph.copy()
|
||||
) # Create a copy to avoid modifying the original graph
|
||||
else:
|
||||
# Find nodes with matching node id based on search_mode
|
||||
nodes_to_explore = []
|
||||
for n, attr in graph.nodes(data=True):
|
||||
node_str = str(n)
|
||||
if not inclusive:
|
||||
if node_label == node_str: # Use exact matching
|
||||
nodes_to_explore.append(n)
|
||||
else: # inclusive mode
|
||||
if node_label in node_str: # Use partial matching
|
||||
nodes_to_explore.append(n)
|
||||
# Get degrees of all nodes
|
||||
degrees = dict(graph.degree())
|
||||
# Sort nodes by degree in descending order and take top max_nodes
|
||||
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
if not nodes_to_explore:
|
||||
logger.warning(f"No nodes found with label {node_label}")
|
||||
return result
|
||||
|
||||
# Get subgraph using ego_graph from all matching nodes
|
||||
combined_subgraph = nx.Graph()
|
||||
for start_node in nodes_to_explore:
|
||||
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
||||
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
||||
|
||||
# Get start nodes and direct connected nodes
|
||||
if nodes_to_explore:
|
||||
start_nodes = set(nodes_to_explore)
|
||||
# Get nodes directly connected to all start nodes
|
||||
for start_node in start_nodes:
|
||||
direct_connected_nodes.update(
|
||||
combined_subgraph.neighbors(start_node)
|
||||
)
|
||||
|
||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||
direct_connected_nodes -= start_nodes
|
||||
|
||||
subgraph = combined_subgraph
|
||||
|
||||
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
|
||||
if min_degree > 0:
|
||||
nodes_to_keep = [
|
||||
node
|
||||
for node, degree in subgraph.degree()
|
||||
if node in start_nodes
|
||||
or node in direct_connected_nodes
|
||||
or degree >= min_degree
|
||||
]
|
||||
subgraph = subgraph.subgraph(nodes_to_keep)
|
||||
|
||||
# Check if number of nodes exceeds max_graph_nodes
|
||||
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
||||
origin_nodes = len(subgraph.nodes())
|
||||
node_degrees = dict(subgraph.degree())
|
||||
|
||||
def priority_key(node_item):
|
||||
node, degree = node_item
|
||||
# Priority order: start(2) > directly connected(1) > other nodes(0)
|
||||
if node in start_nodes:
|
||||
priority = 2
|
||||
elif node in direct_connected_nodes:
|
||||
priority = 1
|
||||
else:
|
||||
priority = 0
|
||||
return (priority, degree)
|
||||
|
||||
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
|
||||
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
|
||||
:MAX_GRAPH_NODES
|
||||
]
|
||||
top_node_ids = [node[0] for node in top_nodes]
|
||||
# Create new subgraph and keep nodes only with most degree
|
||||
subgraph = subgraph.subgraph(top_node_ids)
|
||||
# Check if graph is truncated
|
||||
if len(sorted_nodes) > max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
|
||||
f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
|
||||
)
|
||||
|
||||
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
|
||||
# Create subgraph with the highest degree nodes
|
||||
subgraph = graph.subgraph(limited_nodes)
|
||||
else:
|
||||
# Check if node exists
|
||||
if node_label not in graph:
|
||||
logger.warning(f"Node {node_label} not found in the graph")
|
||||
return KnowledgeGraph() # Return empty graph
|
||||
|
||||
# Use BFS to get nodes
|
||||
bfs_nodes = []
|
||||
visited = set()
|
||||
queue = [(node_label, 0)] # (node, depth) tuple
|
||||
|
||||
# Breadth-first search
|
||||
while queue and len(bfs_nodes) < max_nodes:
|
||||
current, depth = queue.pop(0)
|
||||
if current not in visited:
|
||||
visited.add(current)
|
||||
bfs_nodes.append(current)
|
||||
|
||||
# Only explore neighbors if we haven't reached max_depth
|
||||
if depth < max_depth:
|
||||
# Add neighbor nodes to queue with incremented depth
|
||||
neighbors = list(graph.neighbors(current))
|
||||
queue.extend(
|
||||
[(n, depth + 1) for n in neighbors if n not in visited]
|
||||
)
|
||||
|
||||
# Check if graph is truncated - if we still have nodes in the queue
|
||||
# and we've reached max_nodes, then the graph is truncated
|
||||
if queue and len(bfs_nodes) >= max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
|
||||
)
|
||||
|
||||
# Create subgraph with BFS discovered nodes
|
||||
subgraph = graph.subgraph(bfs_nodes)
|
||||
|
||||
# Add nodes to result
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
for node in subgraph.nodes():
|
||||
if str(node) in seen_nodes:
|
||||
continue
|
||||
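For the `"*"` label, the new code keeps only the highest-degree nodes and flags the result as truncated when the graph is larger than `max_nodes`. The sketch below shows that selection on a plain adjacency dict rather than a NetworkX graph; `top_degree_nodes` is a hypothetical name, not a function from this commit.

```python
def top_degree_nodes(adjacency, max_nodes):
    """Return (kept_nodes, is_truncated), ranking nodes by descending degree."""
    degrees = {node: len(neighbors) for node, neighbors in adjacency.items()}
    # sorted() is stable, so nodes with equal degree keep their original order
    ranked = sorted(degrees.items(), key=lambda item: item[1], reverse=True)
    is_truncated = len(ranked) > max_nodes
    kept = [node for node, _ in ranked[:max_nodes]]
    return kept, is_truncated
```

The subgraph is then built from `kept`, so edges leading to dropped nodes disappear with them.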
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
for edge in subgraph.edges():
|
||||
source, target = edge
|
||||
# Ensure unique edge_id for undirected graph
|
||||
if source > target:
|
||||
if str(source) > str(target):
|
||||
source, target = target, source
|
||||
edge_id = f"{source}-{target}"
|
||||
if edge_id in seen_edges:
|
||||
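The change from `source > target` to `str(source) > str(target)` makes the ordering safe when node IDs have mixed types. A minimal sketch of the resulting canonical edge key, assuming edges arrive as `(source, target)` pairs; both function names are hypothetical.

```python
def undirected_edge_id(source, target):
    """Build an order-independent ID so (A, B) and (B, A) map to the same key."""
    a, b = str(source), str(target)
    if a > b:
        a, b = b, a
    return f"{a}-{b}"


def unique_undirected_edges(edges):
    """Return edge IDs in first-seen order, skipping reversed duplicates."""
    seen = set()
    unique = []
    for source, target in edges:
        edge_id = undirected_edge_id(source, target)
        if edge_id in seen:
            continue
        seen.add(edge_id)
        unique.append(edge_id)
    return unique
```

Comparing the string forms avoids the `TypeError` that Python 3 raises when ordering values such as an `int` and a `str`.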
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
return False # Return error
|
||||
|
||||
return True
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all graph data from storage and clean up resources
|
||||
|
||||
This method will:
|
||||
1. Remove the graph storage file if it exists
|
||||
2. Reset the graph to an empty state
|
||||
3. Update flags to notify other processes
|
||||
4. Changes are persisted to disk immediately
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
async with self._storage_lock:
|
||||
# delete _client_file_name
|
||||
if os.path.exists(self._graphml_xml_file):
|
||||
os.remove(self._graphml_xml_file)
|
||||
self._graph = nx.Graph()
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
self.storage_updated.value = False
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping graph {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -8,17 +8,15 @@ import uuid
|
||||
from ..utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
import configparser
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("qdrant-client"):
|
||||
pm.install("qdrant-client")
|
||||
|
||||
from qdrant_client import QdrantClient, models
|
||||
from qdrant_client import QdrantClient, models # type: ignore
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
def compute_mdhash_id_for_qdrant(
|
||||
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for prefix '{prefix}': {e}")
|
||||
return []
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
"""Get vector data by its ID
|
||||
|
||||
Args:
|
||||
id: The unique identifier of the vector
|
||||
|
||||
Returns:
|
||||
The vector data if found, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Convert to Qdrant compatible ID
|
||||
qdrant_id = compute_mdhash_id_for_qdrant(id)
|
||||
|
||||
# Retrieve the point by ID
|
||||
result = self._client.retrieve(
|
||||
collection_name=self.namespace,
|
||||
ids=[qdrant_id],
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return result[0].payload
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
"""Get multiple vector data by their IDs
|
||||
|
||||
Args:
|
||||
ids: List of unique identifiers
|
||||
|
||||
Returns:
|
||||
List of vector data objects that were found
|
||||
"""
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Convert to Qdrant compatible IDs
|
||||
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
|
||||
|
||||
# Retrieve the points by IDs
|
||||
results = self._client.retrieve(
|
||||
collection_name=self.namespace,
|
||||
ids=qdrant_ids,
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
return [point.payload for point in results]
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||
return []
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop all vector data from storage and clean up resources
|
||||
|
||||
This method will delete all data from the Qdrant collection.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Operation status and message
|
||||
- On success: {"status": "success", "message": "data dropped"}
|
||||
- On failure: {"status": "error", "message": "<error details>"}
|
||||
"""
|
||||
try:
|
||||
# Delete the collection and recreate it
|
||||
if self._client.collection_exists(self.namespace):
|
||||
self._client.delete_collection(self.namespace)
|
||||
|
||||
# Recreate the collection
|
||||
QdrantVectorDBStorage.create_collection_if_not_exist(
|
||||
self._client,
|
||||
self.namespace,
|
||||
vectors_config=models.VectorParams(
|
||||
size=self.embedding_func.embedding_dim,
|
||||
distance=models.Distance.COSINE,
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
|
||||
)
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
@@ -12,6 +12,7 @@ if not pm.is_installed("redis"):
|
||||
from redis.asyncio import Redis, ConnectionPool
|
||||
from redis.exceptions import RedisError, ConnectionError
|
||||
from lightrag.utils import logger, compute_mdhash_id
|
||||
|
||||
from lightrag.base import BaseKVStorage
|
||||
import json
|
||||
|
||||
@@ -122,6 +123,10 @@ class RedisKVStorage(BaseKVStorage):
|
||||
logger.error(f"JSON encode error during upsert: {e}")
|
||||
raise
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Redis handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete entries with specified IDs"""
|
||||
if not ids:
|
||||
@@ -138,71 +143,52 @@ class RedisKVStorage(BaseKVStorage):
|
||||
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
|
||||
)
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete an entity by name"""
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Important notes for Redis storage:
|
||||
1. This will immediately delete the specified cache modes from Redis
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
True: if the cache was dropped successfully
|
||||
False: if the cache drop failed
|
||||
"""
|
||||
if not modes:
|
||||
return False
|
||||
|
||||
try:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(
|
||||
f"Attempting to delete entity {entity_name} with ID {entity_id}"
|
||||
)
|
||||
await self.delete(modes)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by removing all keys under the current namespace.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||
"""
|
||||
async with self._get_redis_connection() as redis:
|
||||
result = await redis.delete(f"{self.namespace}:{entity_id}")
|
||||
|
||||
if result:
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
else:
|
||||
logger.debug(f"Entity {entity_name} not found in storage")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete all relations associated with an entity"""
|
||||
try:
|
||||
async with self._get_redis_connection() as redis:
|
||||
cursor = 0
|
||||
relation_keys = []
|
||||
pattern = f"{self.namespace}:*"
|
||||
keys = await redis.keys(f"{self.namespace}:*")
|
||||
|
||||
while True:
|
||||
cursor, keys = await redis.scan(cursor, match=pattern)
|
||||
|
||||
# Process keys in batches
|
||||
if keys:
|
||||
pipe = redis.pipeline()
|
||||
for key in keys:
|
||||
pipe.get(key)
|
||||
values = await pipe.execute()
|
||||
pipe.delete(key)
|
||||
results = await pipe.execute()
|
||||
deleted_count = sum(results)
|
||||
|
||||
for key, value in zip(keys, values):
|
||||
if value:
|
||||
try:
|
||||
data = json.loads(value)
|
||||
if (
|
||||
data.get("src_id") == entity_name
|
||||
or data.get("tgt_id") == entity_name
|
||||
):
|
||||
relation_keys.append(key)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON in key {key}")
|
||||
continue
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
# Delete relations in batches
|
||||
if relation_keys:
|
||||
# Delete in chunks to avoid too many arguments
|
||||
chunk_size = 1000
|
||||
for i in range(0, len(relation_keys), chunk_size):
|
||||
chunk = relation_keys[i:i + chunk_size]
|
||||
deleted = await redis.delete(*chunk)
|
||||
logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
|
||||
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
|
||||
return {"status": "success", "message": f"{deleted_count} keys dropped"}
|
||||
else:
|
||||
logger.debug(f"No relations found for entity {entity_name}")
|
||||
logger.info(f"No keys found to drop in {self.namespace}")
|
||||
return {"status": "success", "message": "no keys to drop"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
||||
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Redis handles persistence automatically
|
||||
pass
|
||||
|
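The rewritten Redis `drop()` above lists every key under `namespace:*`, deletes them, and reports how many were removed. The sketch below shows that flow against an in-memory stand-in so it runs without a server. `InMemoryRedis` is hypothetical and implements only `keys` and `delete`; it is not the redis-py client, and the pipeline used in the diff is replaced here by a single multi-key `delete` call.

```python
import asyncio


class InMemoryRedis:
    """Hypothetical stand-in exposing only `keys` and `delete`; not the redis-py client."""

    def __init__(self, data: dict[str, str]) -> None:
        self._data = dict(data)

    async def keys(self, pattern: str) -> list[str]:
        # Only prefix patterns of the form "prefix*" are handled here.
        prefix = pattern.rstrip("*")
        return [k for k in self._data if k.startswith(prefix)]

    async def delete(self, *names: str) -> int:
        # Return the number of keys that were actually removed.
        return sum(1 for n in names if self._data.pop(n, None) is not None)


async def drop_namespace(client: InMemoryRedis, namespace: str) -> dict[str, str]:
    """Delete every key under `namespace:` and report a status dict."""
    try:
        keys = await client.keys(f"{namespace}:*")
        if not keys:
            return {"status": "success", "message": "no keys to drop"}
        deleted = await client.delete(*keys)
        return {"status": "success", "message": f"{deleted} keys dropped"}
    except Exception as e:
        return {"status": "error", "message": str(e)}


def demo_drop_namespace() -> tuple[dict[str, str], dict[str, str]]:
    """Drop the same namespace twice to show both success messages."""
    client = InMemoryRedis({"docs:a": "1", "docs:b": "2", "other:c": "3"})
    first = asyncio.run(drop_namespace(client, "docs"))
    second = asyncio.run(drop_namespace(client, "docs"))
    return first, second
```

On a real Redis server, `KEYS` walks the whole keyspace in one blocking call, so an incremental `SCAN` loop may be preferable for large databases.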
@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy import create_engine, text # type: ignore
|
||||
|
||||
|
||||
class TiDB:
|
||||
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
# TiDB handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete records with specified IDs from the storage.
|
||||
|
||||
Args:
|
||||
ids: List of record IDs to be deleted
|
||||
"""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
id_field = namespace_to_id(self.namespace)
|
||||
|
||||
if not table_name or not id_field:
|
||||
logger.error(f"Unknown namespace for deletion: {self.namespace}")
|
||||
return
|
||||
|
||||
ids_list = ",".join([f"'{id}'" for id in ids])
|
||||
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
|
||||
|
||||
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
|
||||
logger.info(
|
||||
f"Successfully deleted {len(ids)} records from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting records from {self.namespace}: {e}")
|
||||
|
||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||
"""Delete specific records from storage by cache mode
|
||||
|
||||
Args:
|
||||
modes (list[str]): List of cache modes to be dropped from storage
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
if not modes:
|
||||
return False
|
||||
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
return False
|
||||
|
||||
if table_name != "LIGHTRAG_LLM_CACHE":
|
||||
return False
|
||||
|
||||
# Build a MySQL-style IN clause
|
||||
modes_list = ", ".join([f"'{mode}'" for mode in modes])
|
||||
sql = f"""
|
||||
DELETE FROM {table_name}
|
||||
WHERE workspace = :workspace
|
||||
AND mode IN ({modes_list})
|
||||
"""
|
||||
|
||||
logger.info(f"Deleting cache by modes: {modes}")
|
||||
await self.db.execute(sql, {"workspace": self.db.workspace})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
||||
return False
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by deleting all rows that belong to the current workspace"""
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unknown namespace: {self.namespace}",
|
||||
}
|
||||
|
||||
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
||||
table_name=table_name
|
||||
)
|
||||
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
"""Delete vectors with specified IDs from the storage.
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
id_field = namespace_to_id(self.namespace)
|
||||
|
||||
if not table_name or not id_field:
|
||||
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
||||
return
|
||||
|
||||
ids_list = ",".join([f"'{id}'" for id in ids])
|
||||
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
|
||||
|
||||
try:
|
||||
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
|
||||
logger.debug(
|
||||
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
raise NotImplementedError
|
||||
"""Delete an entity by its name from the vector storage.
|
||||
|
||||
Args:
|
||||
entity_name: The name of the entity to delete
|
||||
"""
|
||||
try:
|
||||
# Construct SQL to delete the entity
|
||||
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
|
||||
WHERE workspace = :workspace AND name = :entity_name"""
|
||||
|
||||
await self.db.execute(
|
||||
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
|
||||
)
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
raise NotImplementedError
|
||||
"""Delete all relations associated with an entity.
|
||||
|
||||
Args:
|
||||
entity_name: The name of the entity whose relations should be deleted
|
||||
"""
|
||||
try:
|
||||
# Delete relations where the entity is either the source or target
|
||||
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
|
||||
WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
|
||||
|
||||
await self.db.execute(
|
||||
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
|
||||
)
|
||||
logger.debug(f"Successfully deleted relations for entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# TiDB handles persistence automatically
|
||||
pass
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by deleting all rows that belong to the current workspace"""
|
||||
try:
|
||||
table_name = namespace_to_table_name(self.namespace)
|
||||
if not table_name:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Unknown namespace: {self.namespace}",
|
||||
}
|
||||
|
||||
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
||||
table_name=table_name
|
||||
)
|
||||
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
||||
return {"status": "success", "message": "data dropped"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
||||
"""Search for records with IDs starting with a specific prefix.
|
||||
|
||||
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
# TiDB handles persistence automatically
|
||||
pass
|
||||
|
||||
async def drop(self) -> dict[str, str]:
|
||||
"""Drop the storage by deleting all graph edges and nodes in the current workspace"""
|
||||
try:
|
||||
drop_sql = """
|
||||
DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
|
||||
DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
|
||||
"""
|
||||
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
||||
return {"status": "success", "message": "graph data dropped"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Delete a node and all its related edges
|
||||
|
||||
@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
|
||||
FROM LIGHTRAG_DOC_CHUNKS
|
||||
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
|
||||
""",
|
||||
# Drop tables
|
||||
"drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
|
||||
}
|
||||
|
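The TiDB `delete` and `drop_cache_by_modes` methods above build their `IN (...)` lists by quoting each value directly into the SQL string, while `workspace` is passed as a bound parameter. One possible alternative, sketched here as an assumption and not as what the commit does, is to bind the list values as named parameters too. `build_in_clause` and `build_delete_sql` are hypothetical helpers; the table and column names are still formatted into the string and are assumed to come from trusted code, as in the diff.

```python
def build_in_clause(values: list[str], prefix: str = "id") -> tuple[str, dict[str, str]]:
    """Return (placeholders, params) for a named-parameter IN (...) list."""
    params = {f"{prefix}_{i}": value for i, value in enumerate(values)}
    placeholders = ", ".join(f":{name}" for name in params)
    return placeholders, params


def build_delete_sql(
    table_name: str, id_field: str, workspace: str, ids: list[str]
) -> tuple[str, dict[str, str]]:
    """Build a DELETE statement whose workspace and IDs are all bound parameters."""
    placeholders, params = build_in_clause(ids)
    sql = (
        f"DELETE FROM {table_name} "
        f"WHERE workspace = :workspace AND {id_field} IN ({placeholders})"
    )
    return sql, {"workspace": workspace, **params}
```

With this shape, an ID that contains a quote character stays inside the parameter dict and never becomes part of the SQL text. Whether the project's `db.execute` wrapper accepts a parameter dict of this size is an assumption that would need checking.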
@@ -13,7 +13,6 @@ import pandas as pd
|
||||
|
||||
|
||||
from lightrag.kg import (
|
||||
STORAGE_ENV_REQUIREMENTS,
|
||||
STORAGES,
|
||||
verify_storage_implementation,
|
||||
)
|
||||
@@ -230,6 +229,7 @@ class LightRAG:
|
||||
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
"""Additional parameters for vector database storage."""
|
||||
|
||||
# TODO:deprecated, remove in the future, use WORKSPACE instead
|
||||
namespace_prefix: str = field(default="")
|
||||
"""Prefix for namespacing stored data across different environments."""
|
||||
|
||||
@@ -510,36 +510,22 @@ class LightRAG:
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 3,
|
||||
min_degree: int = 0,
|
||||
inclusive: bool = False,
|
||||
max_nodes: int = 1000,
|
||||
) -> KnowledgeGraph:
|
||||
"""Get knowledge graph for a given label
|
||||
|
||||
Args:
|
||||
node_label (str): Label to get knowledge graph for
|
||||
max_depth (int): Maximum depth of graph
|
||||
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
|
||||
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
|
||||
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: Knowledge graph containing nodes and edges
|
||||
"""
|
||||
# get params supported by get_knowledge_graph of specified storage
|
||||
import inspect
|
||||
|
||||
storage_params = inspect.signature(
|
||||
self.chunk_entity_relation_graph.get_knowledge_graph
|
||||
).parameters
|
||||
|
||||
kwargs = {"node_label": node_label, "max_depth": max_depth}
|
||||
|
||||
if "min_degree" in storage_params and min_degree > 0:
|
||||
kwargs["min_degree"] = min_degree
|
||||
|
||||
if "inclusive" in storage_params:
|
||||
kwargs["inclusive"] = inclusive
|
||||
|
||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
|
||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
||||
node_label, max_depth, max_nodes
|
||||
)
|
||||
|
||||
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
||||
import_path = STORAGES[storage_name]
|
||||
@@ -1449,6 +1435,7 @@ class LightRAG:
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.adelete_by_entity(entity_name))
|
||||
|
||||
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
|
||||
async def adelete_by_entity(self, entity_name: str) -> None:
|
||||
try:
|
||||
await self.entities_vdb.delete_entity(entity_name)
|
||||
@@ -1486,6 +1473,7 @@ class LightRAG:
|
||||
self.adelete_by_relation(source_entity, target_entity)
|
||||
)
|
||||
|
||||
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
|
||||
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
|
||||
"""Asynchronously delete a relation between two entities.
|
||||
|
||||
@@ -1494,6 +1482,7 @@ class LightRAG:
|
||||
target_entity: Name of the target entity
|
||||
"""
|
||||
try:
|
||||
# TODO: check if has_edge function works on reverse relation
|
||||
# Check if the relation exists
|
||||
edge_exists = await self.chunk_entity_relation_graph.has_edge(
|
||||
source_entity, target_entity
|
||||
@@ -1554,6 +1543,7 @@ class LightRAG:
|
||||
"""
|
||||
return await self.doc_status.get_docs_by_status(status)
|
||||
|
||||
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
|
||||
async def adelete_by_doc_id(self, doc_id: str) -> None:
|
||||
"""Delete a document and all its related data
|
||||
|
||||
@@ -1586,6 +1576,8 @@ class LightRAG:
|
||||
chunk_ids = set(related_chunks.keys())
|
||||
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
|
||||
|
||||
# TODO: self.entities_vdb.client_storage only works for local storage, need to fix this
|
||||
|
||||
# 3. Before deleting, check the related entities and relationships for these chunks
|
||||
for chunk_id in chunk_ids:
|
||||
# Check entities
|
||||
@@ -1857,24 +1849,6 @@ class LightRAG:
|
||||
|
||||
return result
|
||||
|
||||
def check_storage_env_vars(self, storage_name: str) -> None:
|
||||
"""Check if all required environment variables for storage implementation exist
|
||||
|
||||
Args:
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are missing
|
||||
"""
|
||||
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' requires the following "
|
||||
f"environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
async def aclear_cache(self, modes: list[str] | None = None) -> None:
|
||||
"""Clear cache data from the LLM response cache storage.
|
||||
|
||||
@@ -1906,12 +1880,18 @@ class LightRAG:
|
||||
try:
|
||||
# Reset the cache storage for specified mode
|
||||
if modes:
|
||||
await self.llm_response_cache.delete(modes)
|
||||
success = await self.llm_response_cache.drop_cache_by_modes(modes)
|
||||
if success:
|
||||
logger.info(f"Cleared cache for modes: {modes}")
|
||||
else:
|
||||
logger.warning(f"Failed to clear cache for modes: {modes}")
|
||||
else:
|
||||
# Clear all modes
|
||||
await self.llm_response_cache.delete(valid_modes)
|
||||
success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
|
||||
if success:
|
||||
logger.info("Cleared all cache")
|
||||
else:
|
||||
logger.warning("Failed to clear all cache")
|
||||
|
||||
await self.llm_response_cache.index_done_callback()
|
||||
|
||||
@@ -1922,6 +1902,7 @@ class LightRAG:
|
||||
"""Synchronous version of aclear_cache."""
|
||||
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
|
||||
|
||||
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
|
||||
async def aedit_entity(
|
||||
self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
|
||||
) -> dict[str, Any]:
|
||||
@@ -2134,6 +2115,7 @@ class LightRAG:
|
||||
]
|
||||
)
|
||||
|
||||
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
|
||||
async def aedit_relation(
|
||||
self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
@@ -2448,6 +2430,7 @@ class LightRAG:
|
||||
self.acreate_relation(source_entity, target_entity, relation_data)
|
||||
)
|
||||
|
||||
# TODO: Lock all KG-related DBs to ensure consistency across multiple processes
|
||||
async def amerge_entities(
|
||||
self,
|
||||
source_entities: list[str],
|
||||
|
@@ -44,6 +44,47 @@ class InvalidResponseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def create_openai_async_client(
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
client_configs: dict[str, Any] | None = None,
|
||||
) -> AsyncOpenAI:
|
||||
"""Create an AsyncOpenAI client with the given configuration.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
|
||||
client_configs: Additional configuration options for the AsyncOpenAI client.
|
||||
These will override any default configurations but will be overridden by
|
||||
explicit parameters (api_key, base_url).
|
||||
|
||||
Returns:
|
||||
An AsyncOpenAI client instance.
|
||||
"""
|
||||
if not api_key:
|
||||
api_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
default_headers = {
|
||||
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if client_configs is None:
|
||||
client_configs = {}
|
||||
|
||||
# Create a merged config dict with precedence: explicit params > client_configs > defaults
|
||||
merged_configs = {
|
||||
**client_configs,
|
||||
"default_headers": default_headers,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
if base_url is not None:
|
||||
merged_configs["base_url"] = base_url
|
||||
|
||||
return AsyncOpenAI(**merged_configs)
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
@@ -61,29 +102,52 @@ async def openai_complete_if_cache(
|
||||
token_tracker: Any | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Complete a prompt using OpenAI's API with caching support.
|
||||
|
||||
Args:
|
||||
model: The OpenAI model to use.
|
||||
prompt: The prompt to complete.
|
||||
system_prompt: Optional system prompt to include.
|
||||
history_messages: Optional list of previous messages in the conversation.
|
||||
base_url: Optional base URL for the OpenAI API.
|
||||
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||
**kwargs: Additional keyword arguments to pass to the OpenAI API.
|
||||
Special kwargs:
|
||||
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
|
||||
These will be passed to the client constructor but will be overridden by
|
||||
explicit parameters (api_key, base_url).
|
||||
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
|
||||
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
|
||||
|
||||
Returns:
|
||||
The completed text or an async iterator of text chunks if streaming.
|
||||
|
||||
Raises:
|
||||
InvalidResponseError: If the response from OpenAI is invalid or empty.
|
||||
APIConnectionError: If there is a connection error with the OpenAI API.
|
||||
RateLimitError: If the OpenAI API rate limit is exceeded.
|
||||
APITimeoutError: If the OpenAI API request times out.
|
||||
"""
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
if not api_key:
|
||||
api_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
default_headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Set openai logger level to INFO when VERBOSE_DEBUG is off
|
||||
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
|
||||
logging.getLogger("openai").setLevel(logging.INFO)
|
||||
|
||||
openai_async_client = (
|
||||
AsyncOpenAI(default_headers=default_headers, api_key=api_key)
|
||||
if base_url is None
|
||||
else AsyncOpenAI(
|
||||
base_url=base_url, default_headers=default_headers, api_key=api_key
|
||||
)
|
||||
# Extract client configuration options
|
||||
client_configs = kwargs.pop("openai_client_configs", {})
|
||||
|
||||
# Create the OpenAI client
|
||||
openai_async_client = create_openai_async_client(
|
||||
api_key=api_key, base_url=base_url, client_configs=client_configs
|
||||
)
|
||||
|
||||
# Remove special kwargs that shouldn't be passed to OpenAI
|
||||
kwargs.pop("hashing_kv", None)
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
|
||||
# Prepare messages
|
||||
messages: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
@@ -272,21 +336,32 @@ async def openai_embed(
|
||||
model: str = "text-embedding-3-small",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
client_configs: dict[str, Any] | None = None,
|
||||
) -> np.ndarray:
|
||||
if not api_key:
|
||||
api_key = os.environ["OPENAI_API_KEY"]
|
||||
"""Generate embeddings for a list of texts using OpenAI's API.
|
||||
|
||||
default_headers = {
|
||||
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
openai_async_client = (
|
||||
AsyncOpenAI(default_headers=default_headers, api_key=api_key)
|
||||
if base_url is None
|
||||
else AsyncOpenAI(
|
||||
base_url=base_url, default_headers=default_headers, api_key=api_key
|
||||
)
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
model: The OpenAI embedding model to use.
|
||||
base_url: Optional base URL for the OpenAI API.
|
||||
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||
client_configs: Additional configuration options for the AsyncOpenAI client.
|
||||
These will override any default configurations but will be overridden by
|
||||
explicit parameters (api_key, base_url).
|
||||
|
||||
Returns:
|
||||
A numpy array of embeddings, one per input text.
|
||||
|
||||
Raises:
|
||||
APIConnectionError: If there is a connection error with the OpenAI API.
|
||||
RateLimitError: If the OpenAI API rate limit is exceeded.
|
||||
APITimeoutError: If the OpenAI API request times out.
|
||||
"""
|
||||
# Create the OpenAI client
|
||||
openai_async_client = create_openai_async_client(
|
||||
api_key=api_key, base_url=base_url, client_configs=client_configs
|
||||
)
|
||||
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="float"
|
||||
)
|
||||
|
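`create_openai_async_client` above merges three sources of settings, with later dict entries replacing earlier ones. The sketch below reproduces only that merge step as a plain function so the precedence can be checked without the `openai` package. `merge_client_configs` is a hypothetical name, and the headers are passed in as an argument, not built from `__api_version__`.

```python
from __future__ import annotations

from typing import Any


def merge_client_configs(
    api_key: str,
    base_url: str | None = None,
    client_configs: dict[str, Any] | None = None,
    default_headers: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Merge settings so that explicit parameters win over client_configs."""
    merged: dict[str, Any] = {
        **(client_configs or {}),
        # These two keys come after the unpacked dict, so they replace
        # any entries of the same name supplied in client_configs.
        "default_headers": default_headers or {},
        "api_key": api_key,
    }
    if base_url is not None:
        # An explicit base_url replaces one supplied in client_configs.
        merged["base_url"] = base_url
    return merged
```

With this ordering, a `base_url` given in `client_configs` survives only when the explicit `base_url` argument is `None`, and a `default_headers` entry in `client_configs` is always replaced by the built-in headers.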
@@ -26,7 +26,6 @@ from .utils import (
|
||||
CacheData,
|
||||
statistic_data,
|
||||
get_conversation_turns,
|
||||
verbose_debug,
|
||||
)
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
@@ -442,6 +441,13 @@ async def extract_entities(
|
||||
|
||||
processed_chunks = 0
|
||||
total_chunks = len(ordered_chunks)
|
||||
total_entities_count = 0
|
||||
total_relations_count = 0
|
||||
|
||||
# Get lock manager from shared storage
|
||||
from .kg.shared_storage import get_graph_db_lock
|
||||
|
||||
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
||||
|
||||
async def _user_llm_func_with_cache(
|
||||
input_text: str, history_messages: list[dict[str, str]] = None
|
||||
@@ -540,7 +546,7 @@ async def extract_entities(
|
||||
chunk_key_dp (tuple[str, TextChunkSchema]):
|
||||
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
|
||||
"""
|
||||
nonlocal processed_chunks
|
||||
nonlocal processed_chunks, total_entities_count, total_relations_count
|
||||
chunk_key = chunk_key_dp[0]
|
||||
chunk_dp = chunk_key_dp[1]
|
||||
content = chunk_dp["content"]
|
||||
@@ -598,77 +604,34 @@ async def extract_entities(
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
return dict(maybe_nodes), dict(maybe_edges)
|
||||
|
||||
tasks = [_process_single_content(c) for c in ordered_chunks]
|
||||
results = await asyncio.gather(*tasks)
|
||||
# Use graph database lock to ensure atomic merges and updates
|
||||
chunk_entities_data = []
|
||||
chunk_relationships_data = []
|
||||
|
||||
maybe_nodes = defaultdict(list)
|
||||
maybe_edges = defaultdict(list)
|
||||
for m_nodes, m_edges in results:
|
||||
for k, v in m_nodes.items():
|
||||
maybe_nodes[k].extend(v)
|
||||
for k, v in m_edges.items():
|
||||
maybe_edges[tuple(sorted(k))].extend(v)
|
||||
|
||||
from .kg.shared_storage import get_graph_db_lock
|
||||
|
||||
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
||||
|
||||
# Ensure that nodes and edges are merged and upserted atomically
|
||||
async with graph_db_lock:
|
||||
all_entities_data = await asyncio.gather(
|
||||
*[
|
||||
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
|
||||
for k, v in maybe_nodes.items()
|
||||
]
|
||||
# Process and update entities
|
||||
for entity_name, entities in maybe_nodes.items():
|
||||
entity_data = await _merge_nodes_then_upsert(
|
||||
entity_name, entities, knowledge_graph_inst, global_config
|
||||
)
|
||||
chunk_entities_data.append(entity_data)
|
||||
|
||||
all_relationships_data = await asyncio.gather(
|
||||
*[
|
||||
_merge_edges_then_upsert(
|
||||
k[0], k[1], v, knowledge_graph_inst, global_config
|
||||
)
|
||||
for k, v in maybe_edges.items()
|
||||
]
|
||||
# Process and update relationships
|
||||
for edge_key, edges in maybe_edges.items():
|
||||
# Ensure edge direction consistency
|
||||
sorted_edge_key = tuple(sorted(edge_key))
|
||||
edge_data = await _merge_edges_then_upsert(
|
||||
sorted_edge_key[0],
|
||||
sorted_edge_key[1],
|
||||
edges,
|
||||
knowledge_graph_inst,
|
||||
global_config,
|
||||
)
|
||||
chunk_relationships_data.append(edge_data)
|
||||
|
||||
if not (all_entities_data or all_relationships_data):
|
||||
log_message = "Didn't extract any entities and relationships."
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
return
|
||||
|
||||
if not all_entities_data:
|
||||
log_message = "Didn't extract any entities"
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
if not all_relationships_data:
|
||||
log_message = "Didn't extract any relationships"
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
verbose_debug(
|
||||
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
|
||||
)
|
||||
verbose_debug(f"New relationships:{all_relationships_data}")
|
||||
|
||||
if entity_vdb is not None:
|
||||
# Update vector database (within the same lock to ensure atomicity)
|
||||
if entity_vdb is not None and chunk_entities_data:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
||||
"entity_name": dp["entity_name"],
|
||||
@@ -677,11 +640,11 @@ async def extract_entities(
|
||||
"source_id": dp["source_id"],
|
||||
"file_path": dp.get("file_path", "unknown_source"),
|
||||
}
|
||||
for dp in all_entities_data
|
||||
for dp in chunk_entities_data
|
||||
}
|
||||
await entity_vdb.upsert(data_for_vdb)
|
||||
|
||||
if relationships_vdb is not None:
|
||||
if relationships_vdb is not None and chunk_relationships_data:
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||
"src_id": dp["src_id"],
|
||||
@@ -691,10 +654,25 @@ async def extract_entities(
|
||||
"source_id": dp["source_id"],
|
||||
"file_path": dp.get("file_path", "unknown_source"),
|
||||
}
|
||||
for dp in all_relationships_data
|
||||
for dp in chunk_relationships_data
|
||||
}
|
||||
await relationships_vdb.upsert(data_for_vdb)
|
||||
|
||||
# Update counters
|
||||
total_entities_count += len(chunk_entities_data)
|
||||
total_relations_count += len(chunk_relationships_data)
|
||||
|
||||
# Handle all chunks in parallel
|
||||
tasks = [_process_single_content(c) for c in ordered_chunks]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
|
||||
logger.info(log_message)
|
||||
if pipeline_status is not None:
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
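The merge step in this hunk keys relationships on `tuple(sorted(k))`, so `(A, B)` and `(B, A)` collapse into one undirected edge before the upsert. A minimal sketch of that normalization follows; the helper name `merge_chunk_results` and the sample data are hypothetical and not part of LightRAG.

```python
from collections import defaultdict


def merge_chunk_results(results):
    """Merge per-chunk (nodes, edges) dicts; hypothetical helper for illustration."""
    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for name, records in m_nodes.items():
            maybe_nodes[name].extend(records)
        for key, records in m_edges.items():
            # Sorting the endpoints makes (A, B) and (B, A) share one key.
            maybe_edges[tuple(sorted(key))].extend(records)
    return dict(maybe_nodes), dict(maybe_edges)


# Made-up extraction output from two chunks.
chunk_1 = ({"Alice": [{"source_id": "c1"}]}, {("Alice", "Bob"): [{"weight": 1.0}]})
chunk_2 = ({"Alice": [{"source_id": "c2"}]}, {("Bob", "Alice"): [{"weight": 2.0}]})

nodes, edges = merge_chunk_results([chunk_1, chunk_2])
print(sorted(edges))  # [('Alice', 'Bob')]
```

Both edge records end up under the single key `("Alice", "Bob")`, which is what lets the later merge-and-upsert step treat them as one relationship.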
|
||||
|
||||
|
||||
async def kg_query(
|
||||
query: str,
|
||||
@@ -720,8 +698,7 @@ async def kg_query(
|
||||
if cached_response is not None:
|
||||
return cached_response
|
||||
|
||||
# Extract keywords using extract_keywords_only function which already supports conversation history
|
||||
hl_keywords, ll_keywords = await extract_keywords_only(
|
||||
hl_keywords, ll_keywords = await get_keywords_from_query(
|
||||
query, query_param, global_config, hashing_kv
|
||||
)
|
||||
|
||||
@@ -817,6 +794,38 @@ async def kg_query(
|
||||
return response
|
||||
|
||||
|
||||
async def get_keywords_from_query(
    query: str,
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Retrieves high-level and low-level keywords for RAG operations.

    This function checks if keywords are already provided in query parameters,
    and if not, extracts them from the query text using LLM.

    Args:
        query: The user's query text
        query_param: Query parameters that may contain pre-defined keywords
        global_config: Global configuration dictionary
        hashing_kv: Optional key-value storage for caching results

    Returns:
        A tuple containing (high_level_keywords, low_level_keywords)
    """
    # Check if pre-defined keywords are already provided
    if query_param.hl_keywords or query_param.ll_keywords:
        return query_param.hl_keywords, query_param.ll_keywords

    # Extract keywords using extract_keywords_only function which already supports conversation history
    hl_keywords, ll_keywords = await extract_keywords_only(
        query, query_param, global_config, hashing_kv
    )
    return hl_keywords, ll_keywords
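The fallback in `get_keywords_from_query` can be exercised without an LLM. The sketch below is only an illustration: `DemoQueryParam`, `fake_extract`, and `keywords_for` are hypothetical stand-ins for `QueryParam`, `extract_keywords_only`, and the new function, and the whitespace split is a placeholder for real keyword extraction.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class DemoQueryParam:
    """Hypothetical stand-in for QueryParam with only the two keyword fields."""
    hl_keywords: list[str] = field(default_factory=list)
    ll_keywords: list[str] = field(default_factory=list)


async def fake_extract(query: str) -> tuple[list[str], list[str]]:
    """Hypothetical stand-in for the LLM-backed extractor."""
    words = query.split()
    return words[:1], words[1:]


async def keywords_for(query: str, param: DemoQueryParam) -> tuple[list[str], list[str]]:
    # Caller-supplied keywords win; either list being non-empty skips extraction.
    if param.hl_keywords or param.ll_keywords:
        return param.hl_keywords, param.ll_keywords
    return await fake_extract(query)


preset = asyncio.run(
    keywords_for("graph storage locks", DemoQueryParam(hl_keywords=["storage"]))
)
extracted = asyncio.run(keywords_for("graph storage locks", DemoQueryParam()))
print(preset)     # (['storage'], [])
print(extracted)  # (['graph'], ['storage', 'locks'])
```

Note that supplying only one of the two lists returns the other one empty, because the check uses `or`; callers that want both kinds of keywords need to pass both or neither.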
|
||||
|
||||
|
||||
async def extract_keywords_only(
|
||||
text: str,
|
||||
param: QueryParam,
|
||||
@@ -957,8 +966,7 @@ async def mix_kg_vector_query(
|
||||
# 2. Execute knowledge graph and vector searches in parallel
|
||||
async def get_kg_context():
|
||||
try:
|
||||
# Extract keywords using extract_keywords_only function which already supports conversation history
|
||||
hl_keywords, ll_keywords = await extract_keywords_only(
|
||||
hl_keywords, ll_keywords = await get_keywords_from_query(
|
||||
query, query_param, global_config, hashing_kv
|
||||
)
|
||||
|
||||
@@ -1339,7 +1347,9 @@ async def _get_node_data(
|
||||
|
||||
text_units_section_list = [["id", "content", "file_path"]]
|
||||
for i, t in enumerate(use_text_units):
|
||||
text_units_section_list.append([i, t["content"], t["file_path"]])
|
||||
text_units_section_list.append(
|
||||
[i, t["content"], t.get("file_path", "unknown_source")]
|
||||
)
|
||||
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||
return entities_context, relations_context, text_units_context
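The `_get_node_data` hunk swaps `t["file_path"]` for `t.get("file_path", "unknown_source")`, so text units stored without a file path no longer raise `KeyError`. A small sketch of the effect, assuming a hypothetical `rows_to_csv` helper in place of `list_of_list_to_csv` and made-up text units:

```python
import csv
import io


def rows_to_csv(rows):
    """Hypothetical stand-in for list_of_list_to_csv; uses the stdlib csv writer."""
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator="\n").writerows(rows)
    return buffer.getvalue()


text_units = [
    {"content": "chunk with provenance", "file_path": "docs/a.md"},
    {"content": "legacy chunk"},  # no file_path key
]

rows = [["id", "content", "file_path"]]
for i, unit in enumerate(text_units):
    # dict.get avoids the KeyError that unit["file_path"] would raise here.
    rows.append([i, unit["content"], unit.get("file_path", "unknown_source")])

table = rows_to_csv(rows)
print(table)
```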
|
||||
|
||||
@@ -2043,16 +2053,13 @@ async def query_with_keywords(
|
||||
Query response or async iterator
|
||||
"""
|
||||
# Extract keywords
|
||||
hl_keywords, ll_keywords = await extract_keywords_only(
|
||||
text=query,
|
||||
param=param,
|
||||
hl_keywords, ll_keywords = await get_keywords_from_query(
|
||||
query=query,
|
||||
query_param=param,
|
||||
global_config=global_config,
|
||||
hashing_kv=hashing_kv,
|
||||
)
|
||||
|
||||
param.hl_keywords = hl_keywords
|
||||
param.ll_keywords = ll_keywords
|
||||
|
||||
# Create a new string with the prompt and the keywords
|
||||
ll_keywords_str = ", ".join(ll_keywords)
|
||||
hl_keywords_str = ", ".join(hl_keywords)
|
||||
|
@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
|
||||
class KnowledgeGraph(BaseModel):
|
||||
nodes: list[KnowledgeGraphNode] = []
|
||||
edges: list[KnowledgeGraphEdge] = []
|
||||
is_truncated: bool = False
|
||||
|
@@ -3,12 +3,13 @@ import ThemeProvider from '@/components/ThemeProvider'
|
||||
import TabVisibilityProvider from '@/contexts/TabVisibilityProvider'
|
||||
import ApiKeyAlert from '@/components/ApiKeyAlert'
|
||||
import StatusIndicator from '@/components/status/StatusIndicator'
|
||||
import { healthCheckInterval } from '@/lib/constants'
|
||||
import { healthCheckInterval, SiteInfo, webuiPrefix } from '@/lib/constants'
|
||||
import { useBackendState, useAuthStore } from '@/stores/state'
|
||||
import { useSettingsStore } from '@/stores/settings'
|
||||
import { getAuthStatus } from '@/api/lightrag'
|
||||
import SiteHeader from '@/features/SiteHeader'
|
||||
import { InvalidApiKeyError, RequireApiKeError } from '@/api/lightrag'
|
||||
import { ZapIcon } from 'lucide-react'
|
||||
|
||||
import GraphViewer from '@/features/GraphViewer'
|
||||
import DocumentManager from '@/features/DocumentManager'
|
||||
@@ -22,6 +23,7 @@ function App() {
|
||||
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
|
||||
const currentTab = useSettingsStore.use.currentTab()
|
||||
const [apiKeyAlertOpen, setApiKeyAlertOpen] = useState(false)
|
||||
const [initializing, setInitializing] = useState(true) // Add initializing state
|
||||
const versionCheckRef = useRef(false); // Prevent duplicate calls in Vite dev mode
|
||||
|
||||
const handleApiKeyAlertOpenChange = useCallback((open: boolean) => {
|
||||
@@ -55,29 +57,48 @@ function App() {
|
||||
|
||||
// Check if version info was already obtained in login page
|
||||
const versionCheckedFromLogin = sessionStorage.getItem('VERSION_CHECKED_FROM_LOGIN') === 'true';
|
||||
if (versionCheckedFromLogin) return;
|
||||
if (versionCheckedFromLogin) {
|
||||
setInitializing(false); // Skip initialization if already checked
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setInitializing(true); // Start initialization
|
||||
|
||||
// Get version info
|
||||
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
|
||||
if (!token) return;
|
||||
|
||||
try {
|
||||
const status = await getAuthStatus();
|
||||
if (status.core_version || status.api_version) {
|
||||
|
||||
// If auth is not configured and a new token is returned, use the new token
|
||||
if (!status.auth_configured && status.access_token) {
|
||||
useAuthStore.getState().login(
|
||||
status.access_token, // Use the new token
|
||||
true, // Guest mode
|
||||
status.core_version,
|
||||
status.api_version,
|
||||
status.webui_title || null,
|
||||
status.webui_description || null
|
||||
);
|
||||
} else if (token && (status.core_version || status.api_version || status.webui_title || status.webui_description)) {
|
||||
// Otherwise use the old token (if it exists)
|
||||
const isGuestMode = status.auth_mode === 'disabled' || useAuthStore.getState().isGuestMode;
|
||||
// Update version info while maintaining login state
|
||||
useAuthStore.getState().login(
|
||||
token,
|
||||
isGuestMode,
|
||||
status.core_version,
|
||||
status.api_version
|
||||
status.api_version,
|
||||
status.webui_title || null,
|
||||
status.webui_description || null
|
||||
);
|
||||
}
|
||||
|
||||
// Set flag to indicate version info has been checked
|
||||
sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to get version info:', error);
|
||||
} finally {
|
||||
// Ensure initializing is set to false even if there's an error
|
||||
setInitializing(false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -101,6 +122,37 @@ function App() {
|
||||
return (
|
||||
<ThemeProvider>
|
||||
<TabVisibilityProvider>
|
||||
{initializing ? (
|
||||
// Loading state while initializing with simplified header
|
||||
<div className="flex h-screen w-screen flex-col">
|
||||
{/* Simplified header during initialization - matches SiteHeader structure */}
|
||||
<header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
|
||||
<div className="min-w-[200px] w-auto flex items-center">
|
||||
<a href={webuiPrefix} className="flex items-center gap-2">
|
||||
<ZapIcon className="size-4 text-emerald-400" aria-hidden="true" />
|
||||
<span className="font-bold md:inline-block">{SiteInfo.name}</span>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{/* Empty middle section to maintain layout */}
|
||||
<div className="flex h-10 flex-1 items-center justify-center">
|
||||
</div>
|
||||
|
||||
{/* Empty right section to maintain layout */}
|
||||
<nav className="w-[200px] flex items-center justify-end">
|
||||
</nav>
|
||||
</header>
|
||||
|
||||
{/* Loading indicator in content area */}
|
||||
<div className="flex flex-1 items-center justify-center">
|
||||
<div className="text-center">
|
||||
<div className="mb-2 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"></div>
|
||||
<p>Initializing...</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
// Main content after initialization
|
||||
<main className="flex h-screen w-screen overflow-hidden">
|
||||
<Tabs
|
||||
defaultValue={currentTab}
|
||||
@@ -126,6 +178,7 @@ function App() {
|
||||
{enableHealthCheck && <StatusIndicator />}
|
||||
<ApiKeyAlert open={apiKeyAlertOpen} onOpenChange={handleApiKeyAlertOpenChange} />
|
||||
</main>
|
||||
)}
|
||||
</TabVisibilityProvider>
|
||||
</ThemeProvider>
|
||||
)
|
||||
|
@@ -80,7 +80,12 @@ const AppRouter = () => {
|
||||
<ThemeProvider>
|
||||
<Router>
|
||||
<AppContent />
|
||||
<Toaster position="bottom-center" />
|
||||
<Toaster
|
||||
position="bottom-center"
|
||||
theme="system"
|
||||
closeButton
|
||||
richColors
|
||||
/>
|
||||
</Router>
|
||||
</ThemeProvider>
|
||||
)
|
||||
|
@@ -46,6 +46,8 @@ export type LightragStatus = {
|
||||
api_version?: string
|
||||
auth_mode?: 'enabled' | 'disabled'
|
||||
pipeline_busy: boolean
|
||||
webui_title?: string
|
||||
webui_description?: string
|
||||
}
|
||||
|
||||
export type LightragDocumentsScanProgress = {
|
||||
@@ -140,6 +142,8 @@ export type AuthStatusResponse = {
|
||||
message?: string
|
||||
core_version?: string
|
||||
api_version?: string
|
||||
webui_title?: string
|
||||
webui_description?: string
|
||||
}
|
||||
|
||||
export type PipelineStatusResponse = {
|
||||
@@ -163,6 +167,8 @@ export type LoginResponse = {
|
||||
message?: string // Optional message
|
||||
core_version?: string
|
||||
api_version?: string
|
||||
webui_title?: string
|
||||
webui_description?: string
|
||||
}
|
||||
|
||||
export const InvalidApiKeyError = 'Invalid API Key'
|
||||
@@ -221,9 +227,9 @@ axiosInstance.interceptors.response.use(
|
||||
export const queryGraphs = async (
|
||||
label: string,
|
||||
maxDepth: number,
|
||||
minDegree: number
|
||||
maxNodes: number
|
||||
): Promise<LightragGraphType> => {
|
||||
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
|
||||
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&max_nodes=${maxNodes}`)
|
||||
return response.data
|
||||
}
|
||||
|
||||
@@ -382,6 +388,14 @@ export const clearDocuments = async (): Promise<DocActionResponse> => {
|
||||
return response.data
|
||||
}
|
||||
|
||||
export const clearCache = async (modes?: string[]): Promise<{
|
||||
status: 'success' | 'fail'
|
||||
message: string
|
||||
}> => {
|
||||
const response = await axiosInstance.post('/documents/clear_cache', { modes })
|
||||
return response.data
|
||||
}
|
||||
|
||||
export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
|
||||
try {
|
||||
// Add a timeout to the request to prevent hanging
|
||||
|
@@ -1,4 +1,4 @@
|
||||
import { useState, useCallback } from 'react'
|
||||
import { useState, useCallback, useEffect } from 'react'
|
||||
import Button from '@/components/ui/Button'
|
||||
import {
|
||||
Dialog,
|
||||
@@ -6,32 +6,88 @@ import {
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger
|
||||
DialogTrigger,
|
||||
DialogFooter
|
||||
} from '@/components/ui/Dialog'
|
||||
import Input from '@/components/ui/Input'
|
||||
import Checkbox from '@/components/ui/Checkbox'
|
||||
import { toast } from 'sonner'
|
||||
import { errorMessage } from '@/lib/utils'
|
||||
import { clearDocuments } from '@/api/lightrag'
|
||||
import { clearDocuments, clearCache } from '@/api/lightrag'
|
||||
|
||||
import { EraserIcon } from 'lucide-react'
|
||||
import { EraserIcon, AlertTriangleIcon } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
export default function ClearDocumentsDialog() {
|
||||
// Simple Label component
|
||||
const Label = ({
|
||||
htmlFor,
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.LabelHTMLAttributes<HTMLLabelElement>) => (
|
||||
<label
|
||||
htmlFor={htmlFor}
|
||||
className={className}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</label>
|
||||
)
|
||||
|
||||
interface ClearDocumentsDialogProps {
|
||||
onDocumentsCleared?: () => Promise<void>
|
||||
}
|
||||
|
||||
export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocumentsDialogProps) {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
const [confirmText, setConfirmText] = useState('')
|
||||
const [clearCacheOption, setClearCacheOption] = useState(false)
|
||||
const isConfirmEnabled = confirmText.toLowerCase() === 'yes'
|
||||
|
||||
// Reset state when the dialog closes
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
setConfirmText('')
|
||||
setClearCacheOption(false)
|
||||
}
|
||||
}, [open])
|
||||
|
||||
const handleClear = useCallback(async () => {
|
||||
if (!isConfirmEnabled) return
|
||||
|
||||
try {
|
||||
const result = await clearDocuments()
|
||||
if (result.status === 'success') {
|
||||
toast.success(t('documentPanel.clearDocuments.success'))
|
||||
setOpen(false)
|
||||
} else {
|
||||
|
||||
if (result.status !== 'success') {
|
||||
toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
|
||||
setConfirmText('')
|
||||
return
|
||||
}
|
||||
|
||||
toast.success(t('documentPanel.clearDocuments.success'))
|
||||
|
||||
if (clearCacheOption) {
|
||||
try {
|
||||
await clearCache()
|
||||
toast.success(t('documentPanel.clearDocuments.cacheCleared'))
|
||||
} catch (cacheErr) {
|
||||
toast.error(t('documentPanel.clearDocuments.cacheClearFailed', { error: errorMessage(cacheErr) }))
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh document list if provided
|
||||
if (onDocumentsCleared) {
|
||||
onDocumentsCleared().catch(console.error)
|
||||
}
|
||||
|
||||
// Close the dialog after all operations succeed
|
||||
setOpen(false)
|
||||
} catch (err) {
|
||||
toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
|
||||
setConfirmText('')
|
||||
}
|
||||
}, [setOpen, t])
|
||||
}, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared])
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={setOpen}>
|
||||
@@ -42,12 +98,60 @@ export default function ClearDocumentsDialog() {
|
||||
</DialogTrigger>
|
||||
<DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('documentPanel.clearDocuments.title')}</DialogTitle>
|
||||
<DialogDescription>{t('documentPanel.clearDocuments.confirm')}</DialogDescription>
|
||||
<DialogTitle className="flex items-center gap-2 text-red-500 dark:text-red-400 font-bold">
|
||||
<AlertTriangleIcon className="h-5 w-5" />
|
||||
{t('documentPanel.clearDocuments.title')}
|
||||
</DialogTitle>
|
||||
<DialogDescription className="pt-2">
|
||||
{t('documentPanel.clearDocuments.description')}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<Button variant="destructive" onClick={handleClear}>
|
||||
|
||||
<div className="text-red-500 dark:text-red-400 font-semibold mb-4">
|
||||
{t('documentPanel.clearDocuments.warning')}
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
{t('documentPanel.clearDocuments.confirm')}
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="confirm-text" className="text-sm font-medium">
|
||||
{t('documentPanel.clearDocuments.confirmPrompt')}
|
||||
</Label>
|
||||
<Input
|
||||
id="confirm-text"
|
||||
value={confirmText}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => setConfirmText(e.target.value)}
|
||||
placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')}
|
||||
className="w-full"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center space-x-2">
|
||||
<Checkbox
|
||||
id="clear-cache"
|
||||
checked={clearCacheOption}
|
||||
onCheckedChange={(checked: boolean | 'indeterminate') => setClearCacheOption(checked === true)}
|
||||
/>
|
||||
<Label htmlFor="clear-cache" className="text-sm font-medium cursor-pointer">
|
||||
{t('documentPanel.clearDocuments.clearCache')}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => setOpen(false)}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={handleClear}
|
||||
disabled={!isConfirmEnabled}
|
||||
>
|
||||
{t('documentPanel.clearDocuments.confirmButton')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
|
@@ -17,7 +17,11 @@ import { uploadDocument } from '@/api/lightrag'
|
||||
import { UploadIcon } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
export default function UploadDocumentsDialog() {
|
||||
interface UploadDocumentsDialogProps {
|
||||
onDocumentsUploaded?: () => Promise<void>
|
||||
}
|
||||
|
||||
export default function UploadDocumentsDialog({ onDocumentsUploaded }: UploadDocumentsDialogProps) {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
const [isUploading, setIsUploading] = useState(false)
|
||||
@@ -55,6 +59,7 @@ export default function UploadDocumentsDialog() {
|
||||
const handleDocumentsUpload = useCallback(
|
||||
async (filesToUpload: File[]) => {
|
||||
setIsUploading(true)
|
||||
let hasSuccessfulUpload = false
|
||||
|
||||
// Only clear errors for files that are being uploaded, keep errors for rejected files
|
||||
setFileErrors(prev => {
|
||||
@@ -101,6 +106,9 @@ export default function UploadDocumentsDialog() {
|
||||
...prev,
|
||||
[file.name]: result.message
|
||||
}))
|
||||
} else {
|
||||
// Mark that we had at least one successful upload
|
||||
hasSuccessfulUpload = true
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(`Upload failed for ${file.name}:`, err)
|
||||
@@ -142,6 +150,16 @@ export default function UploadDocumentsDialog() {
|
||||
} else {
|
||||
toast.success(t('documentPanel.uploadDocuments.batch.success'), { id: toastId })
|
||||
}
|
||||
|
||||
// Only update if at least one file was uploaded successfully
|
||||
if (hasSuccessfulUpload) {
|
||||
// Refresh document list
|
||||
if (onDocumentsUploaded) {
|
||||
onDocumentsUploaded().catch(err => {
|
||||
console.error('Error refreshing documents:', err)
|
||||
})
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Unexpected error during upload:', err)
|
||||
toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) }), { id: toastId })
|
||||
@@ -149,7 +167,7 @@ export default function UploadDocumentsDialog() {
|
||||
setIsUploading(false)
|
||||
}
|
||||
},
|
||||
[setIsUploading, setProgresses, setFileErrors, t]
|
||||
[setIsUploading, setProgresses, setFileErrors, t, onDocumentsUploaded]
|
||||
)
|
||||
|
||||
return (
|
||||
|
@@ -36,6 +36,8 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
|
||||
const enableEdgeEvents = useSettingsStore.use.enableEdgeEvents()
|
||||
const renderEdgeLabels = useSettingsStore.use.showEdgeLabel()
|
||||
const renderLabels = useSettingsStore.use.showNodeLabel()
|
||||
const minEdgeSize = useSettingsStore.use.minEdgeSize()
|
||||
const maxEdgeSize = useSettingsStore.use.maxEdgeSize()
|
||||
const selectedNode = useGraphStore.use.selectedNode()
|
||||
const focusedNode = useGraphStore.use.focusedNode()
|
||||
const selectedEdge = useGraphStore.use.selectedEdge()
|
||||
@@ -136,6 +138,51 @@ const GraphControl = ({ disableHoverEffect }: { disableHoverEffect?: boolean })
|
||||
registerEvents(events)
|
||||
}, [registerEvents, enableEdgeEvents])
|
||||
|
||||
  /**
   * When edge size settings change, recalculate edge sizes and refresh the sigma instance
   * to ensure changes take effect immediately
   */
  useEffect(() => {
    if (sigma && sigmaGraph) {
      // Get the graph from sigma
      const graph = sigma.getGraph()

      // Find min and max weight values
      let minWeight = Number.MAX_SAFE_INTEGER
      let maxWeight = 0

      graph.forEachEdge(edge => {
        // Get original weight (before scaling)
        const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
        if (typeof weight === 'number') {
          minWeight = Math.min(minWeight, weight)
          maxWeight = Math.max(maxWeight, weight)
        }
      })

      // Scale edge sizes based on weight range and current min/max edge size settings
      const weightRange = maxWeight - minWeight
      if (weightRange > 0) {
        const sizeScale = maxEdgeSize - minEdgeSize
        graph.forEachEdge(edge => {
          const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
          if (typeof weight === 'number') {
            const scaledSize = minEdgeSize + sizeScale * Math.pow((weight - minWeight) / weightRange, 0.5)
            graph.setEdgeAttribute(edge, 'size', scaledSize)
          }
        })
      } else {
        // If all weights are the same, use default size
        graph.forEachEdge(edge => {
          graph.setEdgeAttribute(edge, 'size', minEdgeSize)
        })
      }

      // Refresh the sigma instance to apply changes
      sigma.refresh()
    }
  }, [sigma, sigmaGraph, minEdgeSize, maxEdgeSize])
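The edge-size effect applies square-root scaling: `size = minEdgeSize + (maxEdgeSize - minEdgeSize) * sqrt((weight - minWeight) / weightRange)`. The arithmetic is restated below as a Python sketch for quick checking. It is a simplification, not the component's code: `scale_edge_sizes` is a hypothetical name, and it takes the plain min and max of the weights instead of the effect's running bounds.

```python
import math


def scale_edge_sizes(weights, min_size, max_size):
    """Map raw edge weights to display sizes with square-root scaling."""
    if not weights:
        return []
    lo, hi = min(weights), max(weights)
    weight_range = hi - lo
    if weight_range <= 0:
        # All weights equal: every edge gets the minimum size.
        return [min_size for _ in weights]
    span = max_size - min_size
    return [min_size + span * math.sqrt((w - lo) / weight_range) for w in weights]


sizes = scale_edge_sizes([1, 2, 5, 10], min_size=1, max_size=5)
print([round(s, 3) for s in sizes])  # [1.0, 2.333, 3.667, 5.0]
```

The square root compresses the top of the range, so a few very heavy edges do not leave every other edge at the minimum width.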
|
||||
|
||||
/**
|
||||
* When the component mounts or the hovered node changes
|
||||
* => Setting the sigma reducers
|
||||
|
@@ -1,4 +1,4 @@
|
||||
import { useCallback } from 'react'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import { AsyncSelect } from '@/components/ui/AsyncSelect'
|
||||
import { useSettingsStore } from '@/stores/settings'
|
||||
import { useGraphStore } from '@/stores/graph'
|
||||
@@ -56,6 +56,23 @@ const GraphLabels = () => {
|
||||
[getSearchEngine]
|
||||
)
|
||||
|
||||
// Validate if current queryLabel exists in allDatabaseLabels
|
||||
useEffect(() => {
|
||||
// Only update label when all conditions are met:
|
||||
// 1. allDatabaseLabels is loaded (length > 1, as it has at least '*' by default)
|
||||
// 2. Current label is not the default '*'
|
||||
// 3. Current label doesn't exist in allDatabaseLabels
|
||||
if (
|
||||
allDatabaseLabels.length > 1 &&
|
||||
label &&
|
||||
label !== '*' &&
|
||||
!allDatabaseLabels.includes(label)
|
||||
) {
|
||||
console.log(`Label "${label}" not found in available labels, resetting to default`);
|
||||
useSettingsStore.getState().setQueryLabel('*');
|
||||
}
|
||||
}, [allDatabaseLabels, label]);
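The three conditions in the label-validation effect reduce to a single predicate. A sketch of that rule as a pure function in Python, where `resolve_query_label` is a hypothetical name used only for illustration:

```python
def resolve_query_label(label, all_labels):
    """Return '*' when a non-default label is missing from the loaded label list."""
    labels_loaded = len(all_labels) > 1  # the list always holds at least '*'
    if labels_loaded and label and label != "*" and label not in all_labels:
        return "*"
    return label


print(resolve_query_label("ghost", ["*", "person", "place"]))   # *
print(resolve_query_label("person", ["*", "person", "place"]))  # person
print(resolve_query_label("ghost", ["*"]))                      # ghost
```

The last call shows why the length check matters: until the labels have loaded, a stale label is left alone instead of being reset prematurely.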
|
||||
|
||||
const handleRefresh = useCallback(() => {
|
||||
// Reset fetch status flags
|
||||
useGraphStore.getState().setLabelsFetchAttempted(false)
|
||||
|
41
lightrag_webui/src/components/graph/Legend.tsx
Normal file
@@ -0,0 +1,41 @@
import React from 'react'
import { useTranslation } from 'react-i18next'
import { useGraphStore } from '@/stores/graph'
import { Card } from '@/components/ui/Card'
import { ScrollArea } from '@/components/ui/ScrollArea'

interface LegendProps {
  className?: string
}

const Legend: React.FC<LegendProps> = ({ className }) => {
  const { t } = useTranslation()
  const typeColorMap = useGraphStore.use.typeColorMap()

  if (!typeColorMap || typeColorMap.size === 0) {
    return null
  }

  return (
    <Card className={`p-2 max-w-xs ${className}`}>
      <h3 className="text-sm font-medium mb-2">{t('graphPanel.legend')}</h3>
      <ScrollArea className="max-h-40">
        <div className="flex flex-col gap-1">
          {Array.from(typeColorMap.entries()).map(([type, color]) => (
            <div key={type} className="flex items-center gap-2">
              <div
                className="w-4 h-4 rounded-full"
                style={{ backgroundColor: color }}
              />
              <span className="text-xs truncate" title={type}>
                {type}
              </span>
            </div>
          ))}
        </div>
      </ScrollArea>
    </Card>
  )
}

export default Legend
32
lightrag_webui/src/components/graph/LegendButton.tsx
Normal file
@@ -0,0 +1,32 @@
import { useCallback } from 'react'
import { BookOpenIcon } from 'lucide-react'
import Button from '@/components/ui/Button'
import { controlButtonVariant } from '@/lib/constants'
import { useSettingsStore } from '@/stores/settings'
import { useTranslation } from 'react-i18next'

/**
 * Component that toggles legend visibility.
 */
const LegendButton = () => {
  const { t } = useTranslation()
  const showLegend = useSettingsStore.use.showLegend()
  const setShowLegend = useSettingsStore.use.setShowLegend()

  const toggleLegend = useCallback(() => {
    setShowLegend(!showLegend)
  }, [showLegend, setShowLegend])

  return (
    <Button
      variant={controlButtonVariant}
      onClick={toggleLegend}
      tooltip={t('graphPanel.sideBar.legendControl.toggleLegend')}
      size="icon"
    >
      <BookOpenIcon />
    </Button>
  )
}

export default LegendButton
@@ -8,7 +8,7 @@ import Input from '@/components/ui/Input'
|
||||
import { controlButtonVariant } from '@/lib/constants'
|
||||
import { useSettingsStore } from '@/stores/settings'
|
||||
|
||||
import { SettingsIcon } from 'lucide-react'
|
||||
import { SettingsIcon, Undo2 } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
/**
|
||||
@@ -44,14 +44,17 @@ const LabeledNumberInput = ({
|
||||
onEditFinished,
|
||||
label,
|
||||
min,
|
||||
max
|
||||
max,
|
||||
defaultValue
|
||||
}: {
|
||||
value: number
|
||||
onEditFinished: (value: number) => void
|
||||
label: string
|
||||
min: number
|
||||
max?: number
|
||||
defaultValue?: number
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const [currentValue, setCurrentValue] = useState<number | null>(value)
|
||||
|
||||
const onValueChange = useCallback(
|
||||
@@ -81,6 +84,13 @@ const LabeledNumberInput = ({
|
||||
}
|
||||
}, [value, currentValue, onEditFinished])
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
if (defaultValue !== undefined && value !== defaultValue) {
|
||||
setCurrentValue(defaultValue)
|
||||
onEditFinished(defaultValue)
|
||||
}
|
||||
}, [defaultValue, value, onEditFinished])
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
<label
|
||||
@@ -89,6 +99,7 @@ const LabeledNumberInput = ({
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
<div className="flex items-center gap-1">
|
||||
<Input
|
||||
type="number"
|
||||
value={currentValue === null ? '' : currentValue}
|
||||
@@ -103,6 +114,19 @@ const LabeledNumberInput = ({
|
||||
}
|
||||
}}
|
||||
/>
|
||||
{defaultValue !== undefined && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-6 w-6 flex-shrink-0 hover:bg-muted text-muted-foreground hover:text-foreground"
|
||||
onClick={handleReset}
|
||||
type="button"
|
||||
title={t('graphPanel.sideBar.settings.resetToDefault')}
|
||||
>
|
||||
<Undo2 className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -120,8 +144,10 @@ export default function Settings() {
|
||||
const enableNodeDrag = useSettingsStore.use.enableNodeDrag()
|
||||
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
|
||||
const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
|
||||
const minEdgeSize = useSettingsStore.use.minEdgeSize()
|
||||
const maxEdgeSize = useSettingsStore.use.maxEdgeSize()
|
||||
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
|
||||
const graphMinDegree = useSettingsStore.use.graphMinDegree()
|
||||
const graphMaxNodes = useSettingsStore.use.graphMaxNodes()
|
||||
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
|
||||
|
||||
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
|
||||
@@ -180,15 +206,14 @@ export default function Settings() {
|
||||
}, 300)
|
||||
}, [])
|
||||
|
||||
const setGraphMinDegree = useCallback((degree: number) => {
|
||||
if (degree < 0) return
|
||||
useSettingsStore.setState({ graphMinDegree: degree })
|
||||
const setGraphMaxNodes = useCallback((nodes: number) => {
|
||||
if (nodes < 1 || nodes > 1000) return
|
||||
useSettingsStore.setState({ graphMaxNodes: nodes })
|
||||
const currentLabel = useSettingsStore.getState().queryLabel
|
||||
useSettingsStore.getState().setQueryLabel('')
|
||||
setTimeout(() => {
|
||||
useSettingsStore.getState().setQueryLabel(currentLabel)
|
||||
}, 300)
|
||||
|
||||
}, [])
|
||||
|
||||
const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
|
||||
@@ -269,24 +294,75 @@ export default function Settings() {
|
||||
label={t('graphPanel.sideBar.settings.edgeEvents')}
|
||||
/>
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
<label className="text-sm leading-none font-medium peer-disabled:cursor-not-allowed peer-disabled:opacity-70">
|
||||
{t('graphPanel.sideBar.settings.edgeSizeRange')}
|
||||
</label>
|
||||
<div className="flex items-center gap-2">
|
||||
<Input
|
||||
type="number"
|
||||
value={minEdgeSize}
|
||||
onChange={(e) => {
|
||||
const newValue = Number(e.target.value);
|
||||
if (!isNaN(newValue) && newValue >= 1 && newValue <= maxEdgeSize) {
|
||||
useSettingsStore.setState({ minEdgeSize: newValue });
|
||||
}
|
||||
}}
|
||||
className="h-6 w-16 min-w-0 pr-1"
|
||||
min={1}
|
||||
max={Math.min(maxEdgeSize, 10)}
|
||||
/>
|
||||
<span>-</span>
|
||||
<div className="flex items-center gap-1">
|
||||
<Input
|
||||
type="number"
|
||||
value={maxEdgeSize}
|
||||
onChange={(e) => {
|
||||
const newValue = Number(e.target.value);
|
||||
if (!isNaN(newValue) && newValue >= minEdgeSize && newValue >= 1 && newValue <= 10) {
|
||||
useSettingsStore.setState({ maxEdgeSize: newValue });
|
||||
}
|
||||
}}
|
||||
className="h-6 w-16 min-w-0 pr-1"
|
||||
min={minEdgeSize}
|
||||
max={10}
|
||||
/>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-6 w-6 flex-shrink-0 hover:bg-muted text-muted-foreground hover:text-foreground"
|
||||
onClick={() => useSettingsStore.setState({ minEdgeSize: 1, maxEdgeSize: 5 })}
|
||||
type="button"
|
||||
title={t('graphPanel.sideBar.settings.resetToDefault')}
|
||||
>
|
||||
<Undo2 className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
<LabeledNumberInput
|
||||
label={t('graphPanel.sideBar.settings.maxQueryDepth')}
|
||||
min={1}
|
||||
value={graphQueryMaxDepth}
|
||||
defaultValue={3}
|
||||
onEditFinished={setGraphQueryMaxDepth}
|
||||
/>
|
||||
<LabeledNumberInput
|
||||
label={t('graphPanel.sideBar.settings.minDegree')}
|
||||
min={0}
|
||||
value={graphMinDegree}
|
||||
onEditFinished={setGraphMinDegree}
|
||||
label={t('graphPanel.sideBar.settings.maxNodes')}
|
||||
min={1}
|
||||
max={1000}
|
||||
value={graphMaxNodes}
|
||||
defaultValue={1000}
|
||||
onEditFinished={setGraphMaxNodes}
|
||||
/>
|
||||
<LabeledNumberInput
|
||||
label={t('graphPanel.sideBar.settings.maxLayoutIterations')}
|
||||
min={1}
|
||||
max={30}
|
||||
value={graphLayoutMaxIterations}
|
||||
defaultValue={15}
|
||||
onEditFinished={setGraphLayoutMaxIterations}
|
||||
/>
|
||||
<Separator />
|
||||
|
@@ -8,12 +8,12 @@ import { useTranslation } from 'react-i18next'
|
||||
const SettingsDisplay = () => {
|
||||
const { t } = useTranslation()
|
||||
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
|
||||
const graphMinDegree = useSettingsStore.use.graphMinDegree()
|
||||
const graphMaxNodes = useSettingsStore.use.graphMaxNodes()
|
||||
|
||||
return (
|
||||
<div className="absolute bottom-4 left-[calc(1rem+2.5rem)] flex items-center gap-2 text-xs text-gray-400">
|
||||
<div>{t('graphPanel.sideBar.settings.depth')}: {graphQueryMaxDepth}</div>
|
||||
<div>{t('graphPanel.sideBar.settings.degree')}: {graphMinDegree}</div>
|
||||
<div>{t('graphPanel.sideBar.settings.max')}: {graphMaxNodes}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
@@ -4,14 +4,14 @@ import { useTranslation } from 'react-i18next'
|
||||
const StatusCard = ({ status }: { status: LightragStatus | null }) => {
|
||||
const { t } = useTranslation()
|
||||
if (!status) {
|
||||
return <div className="text-muted-foreground text-sm">{t('graphPanel.statusCard.unavailable')}</div>
|
||||
return <div className="text-foreground text-xs">{t('graphPanel.statusCard.unavailable')}</div>
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-w-[300px] space-y-3 text-sm">
|
||||
<div className="min-w-[300px] space-y-2 text-xs">
|
||||
<div className="space-y-1">
|
||||
<h4 className="font-medium">{t('graphPanel.statusCard.storageInfo')}</h4>
|
||||
<div className="text-muted-foreground grid grid-cols-2 gap-1">
|
||||
<div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
|
||||
<span>{t('graphPanel.statusCard.workingDirectory')}:</span>
|
||||
<span className="truncate">{status.working_directory}</span>
|
||||
<span>{t('graphPanel.statusCard.inputDirectory')}:</span>
|
||||
@@ -21,7 +21,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {
|
||||
|
||||
<div className="space-y-1">
|
||||
<h4 className="font-medium">{t('graphPanel.statusCard.llmConfig')}</h4>
|
||||
<div className="text-muted-foreground grid grid-cols-2 gap-1">
|
||||
<div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
|
||||
<span>{t('graphPanel.statusCard.llmBinding')}:</span>
|
||||
<span>{status.configuration.llm_binding}</span>
|
||||
<span>{t('graphPanel.statusCard.llmBindingHost')}:</span>
|
||||
@@ -35,7 +35,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {
|
||||
|
||||
<div className="space-y-1">
|
||||
<h4 className="font-medium">{t('graphPanel.statusCard.embeddingConfig')}</h4>
|
||||
<div className="text-muted-foreground grid grid-cols-2 gap-1">
|
||||
<div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
|
||||
<span>{t('graphPanel.statusCard.embeddingBinding')}:</span>
|
||||
<span>{status.configuration.embedding_binding}</span>
|
||||
<span>{t('graphPanel.statusCard.embeddingBindingHost')}:</span>
|
||||
@@ -47,7 +47,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {
|
||||
|
||||
<div className="space-y-1">
|
||||
<h4 className="font-medium">{t('graphPanel.statusCard.storageConfig')}</h4>
|
||||
<div className="text-muted-foreground grid grid-cols-2 gap-1">
|
||||
<div className="text-foreground grid grid-cols-[120px_1fr] gap-1">
|
||||
<span>{t('graphPanel.statusCard.kvStorage')}:</span>
|
||||
<span>{status.configuration.kv_storage}</span>
|
||||
<span>{t('graphPanel.statusCard.docStatusStorage')}:</span>
|
||||
|
32
lightrag_webui/src/components/status/StatusDialog.tsx
Normal file
@@ -0,0 +1,32 @@
|
||||
import { LightragStatus } from '@/api/lightrag'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/Dialog'
|
||||
import StatusCard from './StatusCard'
|
||||
|
||||
interface StatusDialogProps {
|
||||
open: boolean
|
||||
onOpenChange: (open: boolean) => void
|
||||
status: LightragStatus | null
|
||||
}
|
||||
|
||||
const StatusDialog = ({ open, onOpenChange, status }: StatusDialogProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="sm:max-w-[500px]">
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t('graphPanel.statusDialog.title')}</DialogTitle>
|
||||
</DialogHeader>
|
||||
<StatusCard status={status} />
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
export default StatusDialog
|
@@ -1,8 +1,7 @@
|
||||
import { cn } from '@/lib/utils'
|
||||
import { useBackendState } from '@/stores/state'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/Popover'
|
||||
import StatusCard from '@/components/status/StatusCard'
|
||||
import StatusDialog from './StatusDialog'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
const StatusIndicator = () => {
|
||||
@@ -11,6 +10,7 @@ const StatusIndicator = () => {
|
||||
const lastCheckTime = useBackendState.use.lastCheckTime()
|
||||
const status = useBackendState.use.status()
|
||||
const [animate, setAnimate] = useState(false)
|
||||
const [dialogOpen, setDialogOpen] = useState(false)
|
||||
|
||||
// listen to health change
|
||||
useEffect(() => {
|
||||
@@ -21,9 +21,10 @@ const StatusIndicator = () => {
|
||||
|
||||
return (
|
||||
<div className="fixed right-4 bottom-4 flex items-center gap-2 opacity-80 select-none">
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<div className="flex cursor-help items-center gap-2">
|
||||
<div
|
||||
className="flex cursor-pointer items-center gap-2"
|
||||
onClick={() => setDialogOpen(true)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'h-3 w-3 rounded-full transition-all duration-300',
|
||||
@@ -38,11 +39,12 @@ const StatusIndicator = () => {
|
||||
{health ? t('graphPanel.statusIndicator.connected') : t('graphPanel.statusIndicator.disconnected')}
|
||||
</span>
|
||||
</div>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-auto" side="top" align="end">
|
||||
<StatusCard status={status} />
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
|
||||
<StatusDialog
|
||||
open={dialogOpen}
|
||||
onOpenChange={setDialogOpen}
|
||||
status={status}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
@@ -11,7 +11,7 @@ const Checkbox = React.forwardRef<
|
||||
<CheckboxPrimitive.Root
|
||||
ref={ref}
|
||||
className={cn(
|
||||
'peer border-primary ring-offset-background focus-visible:ring-ring data-[state=checked]:bg-primary data-[state=checked]:text-primary-foreground h-4 w-4 shrink-0 rounded-sm border focus-visible:ring-2 focus-visible:ring-offset-2 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
|
||||
'peer border-primary ring-offset-background focus-visible:ring-ring data-[state=checked]:bg-muted data-[state=checked]:text-muted-foreground h-4 w-4 shrink-0 rounded-sm border focus-visible:ring-2 focus-visible:ring-offset-2 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
@@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<'input'>>(
|
||||
<input
|
||||
type={type}
|
||||
className={cn(
|
||||
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-100 [&::-webkit-outer-spin-button]:opacity-100',
|
||||
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-50 [&::-webkit-outer-spin-button]:opacity-50',
|
||||
className
|
||||
)}
|
||||
ref={ref}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react'
|
||||
import { useState, useEffect, useCallback, useMemo, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useSettingsStore } from '@/stores/settings'
|
||||
import Button from '@/components/ui/Button'
|
||||
@@ -16,15 +16,17 @@ import EmptyCard from '@/components/ui/EmptyCard'
|
||||
import UploadDocumentsDialog from '@/components/documents/UploadDocumentsDialog'
|
||||
import ClearDocumentsDialog from '@/components/documents/ClearDocumentsDialog'
|
||||
|
||||
import { getDocuments, scanNewDocuments, DocsStatusesResponse } from '@/api/lightrag'
|
||||
import { getDocuments, scanNewDocuments, DocsStatusesResponse, DocStatus, DocStatusResponse } from '@/api/lightrag'
|
||||
import { errorMessage } from '@/lib/utils'
|
||||
import { toast } from 'sonner'
|
||||
import { useBackendState } from '@/stores/state'
|
||||
|
||||
import { RefreshCwIcon, ActivityIcon, ArrowUpIcon, ArrowDownIcon } from 'lucide-react'
|
||||
import { DocStatusResponse } from '@/api/lightrag'
|
||||
import { RefreshCwIcon, ActivityIcon, ArrowUpIcon, ArrowDownIcon, FilterIcon } from 'lucide-react'
|
||||
import PipelineStatusDialog from '@/components/documents/PipelineStatusDialog'
|
||||
|
||||
type StatusFilter = DocStatus | 'all';
|
||||
|
||||
|
||||
const getDisplayFileName = (doc: DocStatusResponse, maxLength: number = 20): string => {
|
||||
// Check if file_path exists and is a non-empty string
|
||||
if (!doc.file_path || typeof doc.file_path !== 'string' || doc.file_path.trim() === '') {
|
||||
@@ -148,6 +150,10 @@ export default function DocumentManager() {
|
||||
const [sortField, setSortField] = useState<SortField>('updated_at')
|
||||
const [sortDirection, setSortDirection] = useState<SortDirection>('desc')
|
||||
|
||||
// State for document status filter
|
||||
const [statusFilter, setStatusFilter] = useState<StatusFilter>('all');
|
||||
|
||||
|
||||
// Handle sort column click
|
||||
const handleSort = (field: SortField) => {
|
||||
if (sortField === field) {
|
||||
@@ -161,7 +167,7 @@ export default function DocumentManager() {
|
||||
}
|
||||
|
||||
// Sort documents based on current sort field and direction
|
||||
const sortDocuments = (documents: DocStatusResponse[]) => {
|
||||
const sortDocuments = useCallback((documents: DocStatusResponse[]) => {
|
||||
return [...documents].sort((a, b) => {
|
||||
let valueA, valueB;
|
||||
|
||||
@@ -188,7 +194,50 @@ export default function DocumentManager() {
|
||||
return sortMultiplier * (valueA > valueB ? 1 : valueA < valueB ? -1 : 0);
|
||||
}
|
||||
});
|
||||
}, [sortField, sortDirection, showFileName]);
|
||||
|
||||
const filteredAndSortedDocs = useMemo(() => {
|
||||
if (!docs) return null;
|
||||
|
||||
let filteredDocs = { ...docs };
|
||||
|
||||
if (statusFilter !== 'all') {
|
||||
filteredDocs = {
|
||||
...docs,
|
||||
statuses: {
|
||||
pending: [],
|
||||
processing: [],
|
||||
processed: [],
|
||||
failed: [],
|
||||
[statusFilter]: docs.statuses[statusFilter] || []
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
if (!sortField || !sortDirection) return filteredDocs;
|
||||
|
||||
const sortedStatuses = Object.entries(filteredDocs.statuses).reduce((acc, [status, documents]) => {
|
||||
const sortedDocuments = sortDocuments(documents);
|
||||
acc[status as DocStatus] = sortedDocuments;
|
||||
return acc;
|
||||
}, {} as DocsStatusesResponse['statuses']);
|
||||
|
||||
return { ...filteredDocs, statuses: sortedStatuses };
|
||||
}, [docs, sortField, sortDirection, statusFilter, sortDocuments]);
|
||||
|
||||
// Calculate document counts for each status
|
||||
const documentCounts = useMemo(() => {
|
||||
if (!docs) return { all: 0 } as Record<string, number>;
|
||||
|
||||
const counts: Record<string, number> = { all: 0 };
|
||||
|
||||
Object.entries(docs.statuses).forEach(([status, documents]) => {
|
||||
counts[status as DocStatus] = documents.length;
|
||||
counts.all += documents.length;
|
||||
});
|
||||
|
||||
return counts;
|
||||
}, [docs]);
|
||||
|
||||
// Store previous status counts
|
||||
const prevStatusCounts = useRef({
|
||||
@@ -386,8 +435,8 @@ export default function DocumentManager() {
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex-1" />
|
||||
<ClearDocumentsDialog />
|
||||
<UploadDocumentsDialog />
|
||||
<ClearDocumentsDialog onDocumentsCleared={fetchDocuments} />
|
||||
<UploadDocumentsDialog onDocumentsUploaded={fetchDocuments} />
|
||||
<PipelineStatusDialog
|
||||
open={showPipelineStatus}
|
||||
onOpenChange={setShowPipelineStatus}
|
||||
@@ -398,6 +447,65 @@ export default function DocumentManager() {
|
||||
<CardHeader className="flex-none py-2 px-4">
|
||||
<div className="flex justify-between items-center">
|
||||
<CardTitle>{t('documentPanel.documentManager.uploadedTitle')}</CardTitle>
|
||||
<div className="flex items-center gap-2">
|
||||
<FilterIcon className="h-4 w-4" />
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
size="sm"
|
||||
variant={statusFilter === 'all' ? 'secondary' : 'outline'}
|
||||
onClick={() => setStatusFilter('all')}
|
||||
className={cn(
|
||||
statusFilter === 'all' && 'bg-gray-100 dark:bg-gray-900 font-medium border border-gray-400 dark:border-gray-500 shadow-sm'
|
||||
)}
|
||||
>
|
||||
{t('documentPanel.documentManager.status.all')} ({documentCounts.all})
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={statusFilter === 'processed' ? 'secondary' : 'outline'}
|
||||
onClick={() => setStatusFilter('processed')}
|
||||
className={cn(
|
||||
documentCounts.processed > 0 ? 'text-green-600' : 'text-gray-500',
|
||||
statusFilter === 'processed' && 'bg-green-100 dark:bg-green-900/30 font-medium border border-green-400 dark:border-green-600 shadow-sm'
|
||||
)}
|
||||
>
|
||||
{t('documentPanel.documentManager.status.completed')} ({documentCounts.processed || 0})
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={statusFilter === 'processing' ? 'secondary' : 'outline'}
|
||||
onClick={() => setStatusFilter('processing')}
|
||||
className={cn(
|
||||
documentCounts.processing > 0 ? 'text-blue-600' : 'text-gray-500',
|
||||
statusFilter === 'processing' && 'bg-blue-100 dark:bg-blue-900/30 font-medium border border-blue-400 dark:border-blue-600 shadow-sm'
|
||||
)}
|
||||
>
|
||||
{t('documentPanel.documentManager.status.processing')} ({documentCounts.processing || 0})
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={statusFilter === 'pending' ? 'secondary' : 'outline'}
|
||||
onClick={() => setStatusFilter('pending')}
|
||||
className={cn(
|
||||
documentCounts.pending > 0 ? 'text-yellow-600' : 'text-gray-500',
|
||||
statusFilter === 'pending' && 'bg-yellow-100 dark:bg-yellow-900/30 font-medium border border-yellow-400 dark:border-yellow-600 shadow-sm'
|
||||
)}
|
||||
>
|
||||
{t('documentPanel.documentManager.status.pending')} ({documentCounts.pending || 0})
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant={statusFilter === 'failed' ? 'secondary' : 'outline'}
|
||||
onClick={() => setStatusFilter('failed')}
|
||||
className={cn(
|
||||
documentCounts.failed > 0 ? 'text-red-600' : 'text-gray-500',
|
||||
statusFilter === 'failed' && 'bg-red-100 dark:bg-red-900/30 font-medium border border-red-400 dark:border-red-600 shadow-sm'
|
||||
)}
|
||||
>
|
||||
{t('documentPanel.documentManager.status.failed')} ({documentCounts.failed || 0})
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-gray-500">{t('documentPanel.documentManager.fileNameLabel')}</span>
|
||||
<Button
|
||||
@@ -477,11 +585,8 @@ export default function DocumentManager() {
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody className="text-sm overflow-auto">
|
||||
{Object.entries(docs.statuses).flatMap(([status, documents]) => {
|
||||
// Apply sorting to documents
|
||||
const sortedDocuments = sortDocuments(documents);
|
||||
|
||||
return sortedDocuments.map(doc => (
|
||||
{filteredAndSortedDocs?.statuses && Object.entries(filteredAndSortedDocs.statuses).flatMap(([status, documents]) =>
|
||||
documents.map((doc) => (
|
||||
<TableRow key={doc.id}>
|
||||
<TableCell className="truncate font-mono overflow-visible max-w-[250px]">
|
||||
{showFileName ? (
|
||||
@@ -541,8 +646,8 @@ export default function DocumentManager() {
|
||||
{new Date(doc.updated_at).toLocaleString()}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
));
|
||||
})}
|
||||
)))
|
||||
}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
@@ -18,6 +18,8 @@ import GraphSearch from '@/components/graph/GraphSearch'
|
||||
import GraphLabels from '@/components/graph/GraphLabels'
|
||||
import PropertiesView from '@/components/graph/PropertiesView'
|
||||
import SettingsDisplay from '@/components/graph/SettingsDisplay'
|
||||
import Legend from '@/components/graph/Legend'
|
||||
import LegendButton from '@/components/graph/LegendButton'
|
||||
|
||||
import { useSettingsStore } from '@/stores/settings'
|
||||
import { useGraphStore } from '@/stores/graph'
|
||||
@@ -116,6 +118,7 @@ const GraphViewer = () => {
|
||||
const showPropertyPanel = useSettingsStore.use.showPropertyPanel()
|
||||
const showNodeSearchBar = useSettingsStore.use.showNodeSearchBar()
|
||||
const enableNodeDrag = useSettingsStore.use.enableNodeDrag()
|
||||
const showLegend = useSettingsStore.use.showLegend()
|
||||
|
||||
// Initialize sigma settings once on component mount
|
||||
// All dynamic settings will be updated in GraphControl using useSetSettings
|
||||
@@ -195,6 +198,7 @@ const GraphViewer = () => {
|
||||
<LayoutsControl />
|
||||
<ZoomControl />
|
||||
<FullScreenControl />
|
||||
<LegendButton />
|
||||
<Settings />
|
||||
{/* <ThemeToggle /> */}
|
||||
</div>
|
||||
@@ -205,6 +209,12 @@ const GraphViewer = () => {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{showLegend && (
|
||||
<div className="absolute bottom-10 right-2">
|
||||
<Legend className="bg-background/60 backdrop-blur-lg" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* <div className="absolute bottom-2 right-2 flex flex-col rounded-xl border-2">
|
||||
<MiniMap width="100px" height="100px" />
|
||||
</div> */}
|
||||
|
@@ -51,7 +51,7 @@ const LoginPage = () => {
|
||||
|
||||
if (!status.auth_configured && status.access_token) {
|
||||
// If auth is not configured, use the guest token and redirect
|
||||
login(status.access_token, true, status.core_version, status.api_version)
|
||||
login(status.access_token, true, status.core_version, status.api_version, status.webui_title || null, status.webui_description || null)
|
||||
if (status.message) {
|
||||
toast.info(status.message)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ const LoginPage = () => {
|
||||
|
||||
// Check authentication mode
|
||||
const isGuestMode = response.auth_mode === 'disabled'
|
||||
login(response.access_token, isGuestMode, response.core_version, response.api_version)
|
||||
login(response.access_token, isGuestMode, response.core_version, response.api_version, response.webui_title || null, response.webui_description || null)
|
||||
|
||||
// Set session flag for version check
|
||||
if (response.core_version || response.api_version) {
|
||||
|
@@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { navigationService } from '@/services/navigation'
|
||||
import { ZapIcon, GithubIcon, LogOutIcon } from 'lucide-react'
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip'
|
||||
|
||||
interface NavigationTabProps {
|
||||
value: string
|
||||
@@ -55,7 +56,7 @@ function TabsNavigation() {
|
||||
|
||||
export default function SiteHeader() {
|
||||
const { t } = useTranslation()
|
||||
const { isGuestMode, coreVersion, apiVersion, username } = useAuthStore()
|
||||
const { isGuestMode, coreVersion, apiVersion, username, webuiTitle, webuiDescription } = useAuthStore()
|
||||
|
||||
const versionDisplay = (coreVersion && apiVersion)
|
||||
? `${coreVersion}/${apiVersion}`
|
||||
@@ -67,17 +68,31 @@ export default function SiteHeader() {
|
||||
|
||||
return (
|
||||
<header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
|
||||
<div className="w-[200px] flex items-center">
|
||||
<div className="min-w-[200px] w-auto flex items-center">
|
||||
<a href={webuiPrefix} className="flex items-center gap-2">
|
||||
<ZapIcon className="size-4 text-emerald-400" aria-hidden="true" />
|
||||
{/* <img src='/logo.png' className="size-4" /> */}
|
||||
<span className="font-bold md:inline-block">{SiteInfo.name}</span>
|
||||
{versionDisplay && (
|
||||
<span className="ml-2 text-xs text-gray-500 dark:text-gray-400">
|
||||
v{versionDisplay}
|
||||
</span>
|
||||
)}
|
||||
</a>
|
||||
{webuiTitle && (
|
||||
<div className="flex items-center">
|
||||
<span className="mx-1 text-xs text-gray-500 dark:text-gray-400">|</span>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="font-medium text-sm cursor-default">
|
||||
{webuiTitle}
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
{webuiDescription && (
|
||||
<TooltipContent side="bottom">
|
||||
{webuiDescription}
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex h-10 flex-1 items-center justify-center">
|
||||
@@ -91,6 +106,11 @@ export default function SiteHeader() {
|
||||
|
||||
<nav className="w-[200px] flex items-center justify-end">
|
||||
<div className="flex items-center gap-2">
|
||||
{versionDisplay && (
|
||||
<span className="text-xs text-gray-500 dark:text-gray-400 mr-1">
|
||||
v{versionDisplay}
|
||||
</span>
|
||||
)}
|
||||
<Button variant="ghost" size="icon" side="bottom" tooltip={t('header.projectRepository')}>
|
||||
<a href={SiteInfo.github} target="_blank" rel="noopener noreferrer">
|
||||
<GithubIcon className="size-4" aria-hidden="true" />
|
||||
|
@@ -11,6 +11,35 @@ import { useSettingsStore } from '@/stores/settings'
|
||||
|
||||
import seedrandom from 'seedrandom'
|
||||
|
||||
// Helper function to generate a color based on type
|
||||
const getNodeColorByType = (nodeType: string | undefined): string => {
|
||||
const defaultColor = '#CCCCCC'; // Default color for nodes without a type or undefined type
|
||||
if (!nodeType) {
|
||||
return defaultColor;
|
||||
}
|
||||
|
||||
const typeColorMap = useGraphStore.getState().typeColorMap;
|
||||
|
||||
if (!typeColorMap.has(nodeType)) {
|
||||
// Generate a color based on the type string itself for consistency
|
||||
// Seed the global random number generator based on the node type
|
||||
seedrandom(nodeType, { global: true });
|
||||
// Call randomColor without arguments; it will use the globally seeded Math.random()
|
||||
const newColor = randomColor();
|
||||
|
||||
const newMap = new Map(typeColorMap);
|
||||
newMap.set(nodeType, newColor);
|
||||
useGraphStore.setState({ typeColorMap: newMap });
|
||||
|
||||
return newColor;
|
||||
}
|
||||
|
||||
// Restore the default random seed if necessary, though usually not required for this use case
|
||||
// seedrandom(Date.now().toString(), { global: true });
|
||||
return typeColorMap.get(nodeType) || defaultColor; // Add fallback just in case
|
||||
};
|
||||
|
||||
|
||||
const validateGraph = (graph: RawGraph) => {
|
||||
// Check if graph exists
|
||||
if (!graph) {
|
||||
@@ -68,9 +97,15 @@ export type NodeType = {
|
||||
color: string
|
||||
highlighted?: boolean
|
||||
}
|
||||
export type EdgeType = { label: string }
|
||||
export type EdgeType = {
|
||||
label: string
|
||||
originalWeight?: number
|
||||
size?: number
|
||||
color?: string
|
||||
hidden?: boolean
|
||||
}
|
||||
|
||||
const fetchGraph = async (label: string, maxDepth: number, minDegree: number) => {
|
||||
const fetchGraph = async (label: string, maxDepth: number, maxNodes: number) => {
|
||||
let rawData: any = null;
|
||||
|
||||
// Check if we need to fetch all database labels first
|
||||
@@ -89,8 +124,8 @@ const fetchGraph = async (label: string, maxDepth: number, minDegree: number) =>
|
||||
const queryLabel = label || '*';
|
||||
|
||||
try {
|
||||
console.log(`Fetching graph label: ${queryLabel}, depth: ${maxDepth}, deg: ${minDegree}`);
|
||||
rawData = await queryGraphs(queryLabel, maxDepth, minDegree);
|
||||
console.log(`Fetching graph label: ${queryLabel}, depth: ${maxDepth}, nodes: ${maxNodes}`);
|
||||
rawData = await queryGraphs(queryLabel, maxDepth, maxNodes);
|
||||
} catch (e) {
|
||||
useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!');
|
||||
return null;
|
||||
@@ -106,9 +141,6 @@ const fetchGraph = async (label: string, maxDepth: number, minDegree: number) =>
|
||||
const node = rawData.nodes[i]
|
||||
nodeIdMap[node.id] = i
|
||||
|
||||
// const seed = node.labels.length > 0 ? node.labels[0] : node.id
|
||||
seedrandom(node.id, { global: true })
|
||||
node.color = randomColor()
|
||||
node.x = Math.random()
|
||||
node.y = Math.random()
|
||||
node.degree = 0
|
||||
@@ -169,11 +201,14 @@ const fetchGraph = async (label: string, maxDepth: number, minDegree: number) =>
|
||||
}
|
||||
|
||||
// console.debug({ data: JSON.parse(JSON.stringify(rawData)) })
|
||||
return rawGraph
|
||||
return { rawGraph, is_truncated: rawData.is_truncated }
|
||||
}
|
||||
|
||||
// Create a new graph instance with the raw graph data
|
||||
const createSigmaGraph = (rawGraph: RawGraph | null) => {
|
||||
// Get edge size settings from store
|
||||
const minEdgeSize = useSettingsStore.getState().minEdgeSize
|
||||
const maxEdgeSize = useSettingsStore.getState().maxEdgeSize
|
||||
// Skip graph creation if no data or empty nodes
|
||||
if (!rawGraph || !rawGraph.nodes.length) {
|
||||
console.log('No graph data available, skipping sigma graph creation');
|
||||
@@ -204,8 +239,40 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
|
||||
|
||||
// Add edges from raw graph data
|
||||
for (const rawEdge of rawGraph?.edges ?? []) {
|
||||
// Get weight from edge properties or default to 1
|
||||
const weight = rawEdge.properties?.weight !== undefined ? Number(rawEdge.properties.weight) : 1
|
||||
|
||||
rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
|
||||
label: rawEdge.properties?.keywords || undefined
|
||||
label: rawEdge.properties?.keywords || undefined,
|
||||
size: weight, // Set initial size based on weight
|
||||
originalWeight: weight, // Store original weight for recalculation
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate edge size based on weight range, similar to node size calculation
|
||||
let minWeight = Number.MAX_SAFE_INTEGER
|
||||
let maxWeight = 0
|
||||
|
||||
// Find min and max weight values
|
||||
graph.forEachEdge(edge => {
|
||||
const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
|
||||
minWeight = Math.min(minWeight, weight)
|
||||
maxWeight = Math.max(maxWeight, weight)
|
||||
})
|
||||
|
||||
// Scale edge sizes based on weight range
|
||||
const weightRange = maxWeight - minWeight
|
||||
if (weightRange > 0) {
|
||||
const sizeScale = maxEdgeSize - minEdgeSize
|
||||
graph.forEachEdge(edge => {
|
||||
const weight = graph.getEdgeAttribute(edge, 'originalWeight') || 1
|
||||
const scaledSize = minEdgeSize + sizeScale * Math.pow((weight - minWeight) / weightRange, 0.5)
|
||||
graph.setEdgeAttribute(edge, 'size', scaledSize)
|
||||
})
|
||||
} else {
|
||||
// If all weights are the same, use default size
|
||||
graph.forEachEdge(edge => {
|
||||
graph.setEdgeAttribute(edge, 'size', minEdgeSize)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -218,11 +285,12 @@ const useLightrangeGraph = () => {
|
||||
const rawGraph = useGraphStore.use.rawGraph()
|
||||
const sigmaGraph = useGraphStore.use.sigmaGraph()
|
||||
const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
|
||||
const minDegree = useSettingsStore.use.graphMinDegree()
|
||||
const maxNodes = useSettingsStore.use.graphMaxNodes()
|
||||
const isFetching = useGraphStore.use.isFetching()
|
||||
const nodeToExpand = useGraphStore.use.nodeToExpand()
|
||||
const nodeToPrune = useGraphStore.use.nodeToPrune()
|
||||
|
||||
|
||||
// Use ref to track if data has been loaded and initial load
|
||||
const dataLoadedRef = useRef(false)
|
||||
const initialLoadRef = useRef(false)
|
||||
@@ -292,23 +360,37 @@ const useLightrangeGraph = () => {
|
||||
// Use a local copy of the parameters
|
||||
const currentQueryLabel = queryLabel
|
||||
const currentMaxQueryDepth = maxQueryDepth
|
||||
const currentMinDegree = minDegree
|
||||
const currentMaxNodes = maxNodes
|
||||
|
||||
// Declare a variable to store data promise
|
||||
let dataPromise;
|
||||
let dataPromise: Promise<{ rawGraph: RawGraph | null; is_truncated: boolean | undefined } | null>;
|
||||
|
||||
// 1. If query label is not empty, use fetchGraph
|
||||
if (currentQueryLabel) {
|
||||
dataPromise = fetchGraph(currentQueryLabel, currentMaxQueryDepth, currentMinDegree);
|
||||
dataPromise = fetchGraph(currentQueryLabel, currentMaxQueryDepth, currentMaxNodes);
|
||||
} else {
|
||||
// 2. If query label is empty, set data to null
|
||||
console.log('Query label is empty, show empty graph')
|
||||
dataPromise = Promise.resolve(null);
|
||||
dataPromise = Promise.resolve({ rawGraph: null, is_truncated: false });
|
||||
}
|
||||
|
||||
// 3. Process data
|
||||
dataPromise.then((data) => {
|
||||
dataPromise.then((result) => {
|
||||
const state = useGraphStore.getState()
|
||||
const data = result?.rawGraph;
|
||||
|
||||
// Assign colors based on entity_type *after* fetching
|
||||
if (data && data.nodes) {
|
||||
data.nodes.forEach(node => {
|
||||
// Use entity_type instead of type
|
||||
const nodeEntityType = node.properties?.entity_type as string | undefined;
|
||||
node.color = getNodeColorByType(nodeEntityType);
|
||||
});
|
||||
}
|
||||
|
||||
if (result?.is_truncated) {
|
||||
toast.info(t('graphPanel.dataIsTruncated', 'Graph data is truncated to Max Nodes'));
|
||||
}
|
||||
|
||||
// Reset state
|
||||
state.reset()
|
||||
@@ -336,15 +418,23 @@ const useLightrangeGraph = () => {
|
||||
// Still mark graph as empty for other logic
|
||||
state.setGraphIsEmpty(true);
|
||||
|
||||
// Only clear current label if it's not already empty
|
||||
if (currentQueryLabel) {
|
||||
// Check if the empty graph is due to 401 authentication error
|
||||
const errorMessage = useBackendState.getState().message;
|
||||
const isAuthError = errorMessage && errorMessage.includes('Authentication required');
|
||||
|
||||
// Only clear queryLabel if it's not an auth error and current label is not empty
|
||||
if (!isAuthError && currentQueryLabel) {
|
||||
useSettingsStore.getState().setQueryLabel('');
|
||||
}
|
||||
|
||||
// Clear last successful query label to ensure labels are fetched next time
|
||||
// Only clear last successful query label if it's not an auth error
|
||||
if (!isAuthError) {
|
||||
state.setLastSuccessfulQueryLabel('');
|
||||
} else {
|
||||
console.log('Keep queryLabel for post-login reload');
|
||||
}
|
||||
|
||||
console.log('Graph data is empty, created graph with empty graph node');
|
||||
console.log(`Graph data is empty, created graph with empty graph node. Auth error: ${isAuthError}`);
|
||||
} else {
|
||||
// Create and set new graph
|
||||
const newSigmaGraph = createSigmaGraph(data);
|
||||
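The empty-graph branch above keeps the query label when the backend message indicates an authentication error, so the graph can be reloaded after login. The decision can be sketched as a small pure function in Python; `labels_to_clear` and its tuple return value are hypothetical names used only for this illustration:

```python
def labels_to_clear(error_message, current_query_label):
    """Return (clear_query_label, clear_last_successful_label)."""
    # Mirrors the substring check on the backend message shown in the hunk.
    is_auth_error = bool(error_message) and "Authentication required" in error_message
    # The query label is cleared only when it is non-empty and auth is not the cause.
    clear_query_label = (not is_auth_error) and bool(current_query_label)
    # The last successful label is cleared on any non-auth empty result.
    clear_last_successful = not is_auth_error
    return clear_query_label, clear_last_successful


decision = labels_to_clear("401: Authentication required", "person")  # -> (False, False)
```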
@@ -384,7 +474,7 @@ const useLightrangeGraph = () => {
|
||||
state.setLastSuccessfulQueryLabel('') // Clear last successful query label on error
|
||||
})
|
||||
}
|
||||
}, [queryLabel, maxQueryDepth, minDegree, isFetching, t])
|
||||
}, [queryLabel, maxQueryDepth, maxNodes, isFetching, t])
|
||||
|
||||
// Handle node expansion
|
||||
useEffect(() => {
|
||||
@@ -407,7 +497,7 @@ const useLightrangeGraph = () => {
|
||||
}
|
||||
|
||||
// Fetch the extended subgraph with depth 2
|
||||
const extendedGraph = await queryGraphs(label, 2, 0);
|
||||
const extendedGraph = await queryGraphs(label, 2, 1000);
|
||||
|
||||
if (!extendedGraph || !extendedGraph.nodes || !extendedGraph.edges) {
|
||||
console.error('Failed to fetch extended graph');
|
||||
|
@@ -32,14 +32,24 @@
|
||||
"authDisabled": "تم تعطيل المصادقة. استخدام وضع بدون تسجيل دخول.",
|
||||
"guestMode": "وضع بدون تسجيل دخول"
|
||||
},
|
||||
"common": {
|
||||
"cancel": "إلغاء"
|
||||
},
|
||||
"documentPanel": {
|
||||
"clearDocuments": {
|
||||
"button": "مسح",
|
||||
"tooltip": "مسح المستندات",
|
||||
"title": "مسح المستندات",
|
||||
"description": "سيؤدي هذا إلى إزالة جميع المستندات من النظام",
|
||||
"warning": "تحذير: سيؤدي هذا الإجراء إلى حذف جميع المستندات بشكل دائم ولا يمكن التراجع عنه!",
|
||||
"confirm": "هل تريد حقًا مسح جميع المستندات؟",
|
||||
"confirmPrompt": "اكتب 'yes' لتأكيد هذا الإجراء",
|
||||
"confirmPlaceholder": "اكتب yes للتأكيد",
|
||||
"clearCache": "مسح كاش نموذج اللغة",
|
||||
"confirmButton": "نعم",
|
||||
"success": "تم مسح المستندات بنجاح",
|
||||
"cacheCleared": "تم مسح ذاكرة التخزين المؤقت بنجاح",
|
||||
"cacheClearFailed": "فشل مسح ذاكرة التخزين المؤقت:\n{{error}}",
|
||||
"failed": "فشل مسح المستندات:\n{{message}}",
|
||||
"error": "فشل مسح المستندات:\n{{error}}"
|
||||
},
|
||||
@@ -95,6 +105,7 @@
|
||||
"metadata": "البيانات الوصفية"
|
||||
},
|
||||
"status": {
|
||||
"all": "الكل",
|
||||
"completed": "مكتمل",
|
||||
"processing": "قيد المعالجة",
|
||||
"pending": "معلق",
|
||||
@@ -127,6 +138,11 @@
|
||||
}
|
||||
},
|
||||
"graphPanel": {
|
||||
"dataIsTruncated": "تم اقتصار بيانات الرسم البياني على الحد الأقصى للعقد",
|
||||
"statusDialog": {
|
||||
"title": "إعدادات خادم LightRAG"
|
||||
},
|
||||
"legend": "المفتاح",
|
||||
"sideBar": {
|
||||
"settings": {
|
||||
"settings": "الإعدادات",
|
||||
@@ -139,9 +155,12 @@
|
||||
"hideUnselectedEdges": "إخفاء الحواف غير المحددة",
|
||||
"edgeEvents": "أحداث الحافة",
|
||||
"maxQueryDepth": "أقصى عمق للاستعلام",
|
||||
"minDegree": "الدرجة الدنيا",
|
||||
"maxNodes": "الحد الأقصى للعقد",
|
||||
"maxLayoutIterations": "أقصى تكرارات التخطيط",
|
||||
"depth": "العمق",
|
||||
"resetToDefault": "إعادة التعيين إلى الافتراضي",
|
||||
"edgeSizeRange": "نطاق حجم الحافة",
|
||||
"depth": "D",
|
||||
"max": "Max",
|
||||
"degree": "الدرجة",
|
||||
"apiKey": "مفتاح واجهة برمجة التطبيقات",
|
||||
"enterYourAPIkey": "أدخل مفتاح واجهة برمجة التطبيقات الخاص بك",
|
||||
@@ -171,6 +190,9 @@
|
||||
"fullScreenControl": {
|
||||
"fullScreen": "شاشة كاملة",
|
||||
"windowed": "نوافذ"
|
||||
},
|
||||
"legendControl": {
|
||||
"toggleLegend": "تبديل المفتاح"
|
||||
}
|
||||
},
|
||||
"statusIndicator": {
|
||||
|
@@ -32,14 +32,24 @@
|
||||
"authDisabled": "Authentication is disabled. Using login free mode.",
|
||||
"guestMode": "Login Free"
|
||||
},
|
||||
"common": {
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"documentPanel": {
|
||||
"clearDocuments": {
|
||||
"button": "Clear",
|
||||
"tooltip": "Clear documents",
|
||||
"title": "Clear Documents",
|
||||
"description": "This will remove all documents from the system",
|
||||
"warning": "WARNING: This action will permanently delete all documents and cannot be undone!",
|
||||
"confirm": "Do you really want to clear all documents?",
|
||||
"confirmPrompt": "Type 'yes' to confirm this action",
|
||||
"confirmPlaceholder": "Type yes to confirm",
|
||||
"clearCache": "Clear LLM cache",
|
||||
"confirmButton": "YES",
|
||||
"success": "Documents cleared successfully",
|
||||
"cacheCleared": "Cache cleared successfully",
|
||||
"cacheClearFailed": "Failed to clear cache:\n{{error}}",
|
||||
"failed": "Clear Documents Failed:\n{{message}}",
|
||||
"error": "Clear Documents Failed:\n{{error}}"
|
||||
},
|
||||
@@ -95,6 +105,7 @@
|
||||
"metadata": "Metadata"
|
||||
},
|
||||
"status": {
|
||||
"all": "All",
|
||||
"completed": "Completed",
|
||||
"processing": "Processing",
|
||||
"pending": "Pending",
|
||||
@@ -127,6 +138,11 @@
|
||||
}
|
||||
},
|
||||
"graphPanel": {
|
||||
"dataIsTruncated": "Graph data is truncated to Max Nodes",
|
||||
"statusDialog": {
|
||||
"title": "LightRAG Server Settings"
|
||||
},
|
||||
"legend": "Legend",
|
||||
"sideBar": {
|
||||
"settings": {
|
||||
"settings": "Settings",
|
||||
@@ -139,9 +155,12 @@
|
||||
"hideUnselectedEdges": "Hide Unselected Edges",
|
||||
"edgeEvents": "Edge Events",
|
||||
"maxQueryDepth": "Max Query Depth",
|
||||
"minDegree": "Minimum Degree",
|
||||
"maxNodes": "Max Nodes",
|
||||
"maxLayoutIterations": "Max Layout Iterations",
|
||||
"depth": "Depth",
|
||||
"resetToDefault": "Reset to default",
|
||||
"edgeSizeRange": "Edge Size Range",
|
||||
"depth": "D",
|
||||
"max": "Max",
|
||||
"degree": "Degree",
|
||||
"apiKey": "API Key",
|
||||
"enterYourAPIkey": "Enter your API key",
|
||||
@@ -171,6 +190,9 @@
|
||||
"fullScreenControl": {
|
||||
"fullScreen": "Full Screen",
|
||||
"windowed": "Windowed"
|
||||
},
|
||||
"legendControl": {
|
||||
"toggleLegend": "Toggle Legend"
|
||||
}
|
||||
},
|
||||
"statusIndicator": {
|
||||
|
@@ -32,14 +32,24 @@
|
||||
"authDisabled": "L'authentification est désactivée. Utilisation du mode sans connexion.",
|
||||
"guestMode": "Mode sans connexion"
|
||||
},
|
||||
"common": {
|
||||
"cancel": "Annuler"
|
||||
},
|
||||
"documentPanel": {
|
||||
"clearDocuments": {
|
||||
"button": "Effacer",
|
||||
"tooltip": "Effacer les documents",
|
||||
"title": "Effacer les documents",
|
||||
"description": "Cette action supprimera tous les documents du système",
|
||||
"warning": "ATTENTION : Cette action supprimera définitivement tous les documents et ne peut pas être annulée !",
|
||||
"confirm": "Voulez-vous vraiment effacer tous les documents ?",
|
||||
"confirmPrompt": "Tapez 'yes' pour confirmer cette action",
|
||||
"confirmPlaceholder": "Tapez yes pour confirmer",
|
||||
"clearCache": "Effacer le cache LLM",
|
||||
"confirmButton": "OUI",
|
||||
"success": "Documents effacés avec succès",
|
||||
"cacheCleared": "Cache effacé avec succès",
|
||||
"cacheClearFailed": "Échec de l'effacement du cache :\n{{error}}",
|
||||
"failed": "Échec de l'effacement des documents :\n{{message}}",
|
||||
"error": "Échec de l'effacement des documents :\n{{error}}"
|
||||
},
|
||||
@@ -95,6 +105,7 @@
|
||||
"metadata": "Métadonnées"
|
||||
},
|
||||
"status": {
|
||||
"all": "Tous",
|
||||
"completed": "Terminé",
|
||||
"processing": "En traitement",
|
||||
"pending": "En attente",
|
||||
@@ -127,6 +138,11 @@
|
||||
}
|
||||
},
|
||||
"graphPanel": {
|
||||
"dataIsTruncated": "Les données du graphe sont tronquées au nombre maximum de nœuds",
|
||||
"statusDialog": {
|
||||
"title": "Paramètres du Serveur LightRAG"
|
||||
},
|
||||
"legend": "Légende",
|
||||
"sideBar": {
|
||||
"settings": {
|
||||
"settings": "Paramètres",
|
||||
@@ -139,9 +155,12 @@
|
||||
"hideUnselectedEdges": "Masquer les arêtes non sélectionnées",
|
||||
"edgeEvents": "Événements des arêtes",
|
||||
"maxQueryDepth": "Profondeur maximale de la requête",
|
||||
"minDegree": "Degré minimum",
|
||||
"maxNodes": "Nombre maximum de nœuds",
|
||||
"maxLayoutIterations": "Itérations maximales de mise en page",
|
||||
"depth": "Profondeur",
|
||||
"resetToDefault": "Réinitialiser par défaut",
|
||||
"edgeSizeRange": "Plage de taille des arêtes",
|
||||
"depth": "D",
|
||||
"max": "Max",
|
||||
"degree": "Degré",
|
||||
"apiKey": "Clé API",
|
||||
"enterYourAPIkey": "Entrez votre clé API",
|
||||
@@ -171,6 +190,9 @@
|
||||
"fullScreenControl": {
|
||||
"fullScreen": "Plein écran",
|
||||
"windowed": "Fenêtré"
|
||||
},
|
||||
"legendControl": {
|
||||
"toggleLegend": "Basculer la légende"
|
||||
}
|
||||
},
|
||||
"statusIndicator": {
|
||||
|
@@ -32,14 +32,24 @@
|
||||
"authDisabled": "认证已禁用,使用无需登陆模式。",
|
||||
"guestMode": "无需登陆"
|
||||
},
|
||||
"common": {
|
||||
"cancel": "取消"
|
||||
},
|
||||
"documentPanel": {
|
||||
"clearDocuments": {
|
||||
"button": "清空",
|
||||
"tooltip": "清空文档",
|
||||
"title": "清空文档",
|
||||
"description": "此操作将从系统中移除所有文档",
|
||||
"warning": "警告:此操作将永久删除所有文档,无法恢复!",
|
||||
"confirm": "确定要清空所有文档吗?",
|
||||
"confirmPrompt": "请输入 yes 确认操作",
|
||||
"confirmPlaceholder": "输入 yes 确认",
|
||||
"clearCache": "清空LLM缓存",
|
||||
"confirmButton": "确定",
|
||||
"success": "文档清空成功",
|
||||
"cacheCleared": "缓存清空成功",
|
||||
"cacheClearFailed": "清空缓存失败:\n{{error}}",
|
||||
"failed": "清空文档失败:\n{{message}}",
|
||||
"error": "清空文档失败:\n{{error}}"
|
||||
},
|
||||
@@ -95,6 +105,7 @@
|
||||
"metadata": "元数据"
|
||||
},
|
||||
"status": {
|
||||
"all": "全部",
|
||||
"completed": "已完成",
|
||||
"processing": "处理中",
|
||||
"pending": "等待中",
|
||||
@@ -127,6 +138,11 @@
|
||||
}
|
||||
},
|
||||
"graphPanel": {
|
||||
"dataIsTruncated": "图数据已截断至最大返回节点数",
|
||||
"statusDialog": {
|
||||
"title": "LightRAG 服务器设置"
|
||||
},
|
||||
"legend": "图例",
|
||||
"sideBar": {
|
||||
"settings": {
|
||||
"settings": "设置",
|
||||
@@ -139,9 +155,12 @@
|
||||
"hideUnselectedEdges": "隐藏未选中的边",
|
||||
"edgeEvents": "边事件",
|
||||
"maxQueryDepth": "最大查询深度",
|
||||
"minDegree": "最小邻边数",
|
||||
"maxNodes": "最大返回节点数",
|
||||
"maxLayoutIterations": "最大布局迭代次数",
|
||||
"depth": "深度",
|
||||
"resetToDefault": "重置为默认值",
|
||||
"edgeSizeRange": "边粗细范围",
|
||||
"depth": "深",
|
||||
"max": "Max",
|
||||
"degree": "邻边",
|
||||
"apiKey": "API密钥",
|
||||
"enterYourAPIkey": "输入您的API密钥",
|
||||
@@ -171,6 +190,9 @@
|
||||
"fullScreenControl": {
|
||||
"fullScreen": "全屏",
|
||||
"windowed": "窗口"
|
||||
},
|
||||
"legendControl": {
|
||||
"toggleLegend": "切换图例显示"
|
||||
}
|
||||
},
|
||||
"statusIndicator": {
|
||||
|
@@ -32,7 +32,7 @@ class NavigationService {
|
||||
// Reset backend state
|
||||
useBackendState.getState().clear();
|
||||
|
||||
// Reset retrieval history while preserving other user preferences
|
||||
// Reset retrieval history message while preserving other user preferences
|
||||
useSettingsStore.getState().setRetrievalHistory([]);
|
||||
|
||||
// Clear authentication state
|
||||
|
@@ -77,6 +77,8 @@ interface GraphState {
|
||||
graphIsEmpty: boolean
|
||||
lastSuccessfulQueryLabel: string
|
||||
|
||||
typeColorMap: Map<string, string>
|
||||
|
||||
// Global flags to track data fetching attempts
|
||||
graphDataFetchAttempted: boolean
|
||||
labelsFetchAttempted: boolean
|
||||
@@ -136,6 +138,8 @@ const useGraphStoreBase = create<GraphState>()((set) => ({
|
||||
sigmaInstance: null,
|
||||
allDatabaseLabels: ['*'],
|
||||
|
||||
typeColorMap: new Map<string, string>(),
|
||||
|
||||
searchEngine: null,
|
||||
|
||||
setGraphIsEmpty: (isEmpty: boolean) => set({ graphIsEmpty: isEmpty }),
|
||||
@@ -166,7 +170,6 @@ const useGraphStoreBase = create<GraphState>()((set) => ({
|
||||
searchEngine: null,
|
||||
moveToSelectedNode: false,
|
||||
graphIsEmpty: false
|
||||
// Do not reset lastSuccessfulQueryLabel here as it's used to track query history
|
||||
});
|
||||
},
|
||||
|
||||
@@ -199,6 +202,8 @@ const useGraphStoreBase = create<GraphState>()((set) => ({
|
||||
|
||||
setSigmaInstance: (instance: any) => set({ sigmaInstance: instance }),
|
||||
|
||||
setTypeColorMap: (typeColorMap: Map<string, string>) => set({ typeColorMap }),
|
||||
|
||||
setSearchEngine: (engine: MiniSearch | null) => set({ searchEngine: engine }),
|
||||
resetSearchEngine: () => set({ searchEngine: null }),
|
||||
|
||||
|
@@ -16,6 +16,8 @@ interface SettingsState {
|
||||
// Graph viewer settings
|
||||
showPropertyPanel: boolean
|
||||
showNodeSearchBar: boolean
|
||||
showLegend: boolean
|
||||
setShowLegend: (show: boolean) => void
|
||||
|
||||
showNodeLabel: boolean
|
||||
enableNodeDrag: boolean
|
||||
@@ -24,11 +26,17 @@ interface SettingsState {
|
||||
enableHideUnselectedEdges: boolean
|
||||
enableEdgeEvents: boolean
|
||||
|
||||
minEdgeSize: number
|
||||
setMinEdgeSize: (size: number) => void
|
||||
|
||||
maxEdgeSize: number
|
||||
setMaxEdgeSize: (size: number) => void
|
||||
|
||||
graphQueryMaxDepth: number
|
||||
setGraphQueryMaxDepth: (depth: number) => void
|
||||
|
||||
graphMinDegree: number
|
||||
setGraphMinDegree: (degree: number) => void
|
||||
graphMaxNodes: number
|
||||
setGraphMaxNodes: (nodes: number) => void
|
||||
|
||||
graphLayoutMaxIterations: number
|
||||
setGraphLayoutMaxIterations: (iterations: number) => void
|
||||
@@ -68,6 +76,7 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||
language: 'en',
|
||||
showPropertyPanel: true,
|
||||
showNodeSearchBar: true,
|
||||
showLegend: false,
|
||||
|
||||
showNodeLabel: true,
|
||||
enableNodeDrag: true,
|
||||
@@ -76,8 +85,11 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||
enableHideUnselectedEdges: true,
|
||||
enableEdgeEvents: false,
|
||||
|
||||
minEdgeSize: 1,
|
||||
maxEdgeSize: 1,
|
||||
|
||||
graphQueryMaxDepth: 3,
|
||||
graphMinDegree: 0,
|
||||
graphMaxNodes: 1000,
|
||||
graphLayoutMaxIterations: 15,
|
||||
|
||||
queryLabel: defaultQueryLabel,
|
||||
@@ -130,7 +142,11 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||
|
||||
setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),
|
||||
|
||||
setGraphMinDegree: (degree: number) => set({ graphMinDegree: degree }),
|
||||
setGraphMaxNodes: (nodes: number) => set({ graphMaxNodes: nodes }),
|
||||
|
||||
setMinEdgeSize: (size: number) => set({ minEdgeSize: size }),
|
||||
|
||||
setMaxEdgeSize: (size: number) => set({ maxEdgeSize: size }),
|
||||
|
||||
setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
|
||||
|
||||
@@ -145,12 +161,13 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||
querySettings: { ...state.querySettings, ...settings }
|
||||
})),
|
||||
|
||||
setShowFileName: (show: boolean) => set({ showFileName: show })
|
||||
setShowFileName: (show: boolean) => set({ showFileName: show }),
|
||||
setShowLegend: (show: boolean) => set({ showLegend: show })
|
||||
}),
|
||||
{
|
||||
name: 'settings-storage',
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
version: 9,
|
||||
version: 11,
|
||||
migrate: (state: any, version: number) => {
|
||||
if (version < 2) {
|
||||
state.showEdgeLabel = false
|
||||
@@ -196,6 +213,14 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||
if (version < 9) {
|
||||
state.showFileName = false
|
||||
}
|
||||
if (version < 10) {
|
||||
delete state.graphMinDegree // Remove deprecated parameter
|
||||
state.graphMaxNodes = 1000 // Add new parameter
|
||||
}
|
||||
if (version < 11) {
|
||||
state.minEdgeSize = 1
|
||||
state.maxEdgeSize = 1
|
||||
}
|
||||
return state
|
||||
}
|
||||
}
|
||||
|
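The `migrate` callback above upgrades persisted settings step by step: each `version < N` block applies only the changes introduced in version N, so a state saved at any older version ends up in the current shape. A rough Python sketch of the last two steps, using a plain dict in place of the persisted store (the function name `migrate_settings` is an assumption for illustration):

```python
def migrate_settings(state, version):
    """Apply the version 10 and 11 migration steps to a persisted settings dict."""
    state = dict(state)  # Work on a copy so the caller's dict is untouched.
    if version < 10:
        state.pop("graphMinDegree", None)  # Deprecated parameter.
        state["graphMaxNodes"] = 1000      # New parameter with its default.
    if version < 11:
        state["minEdgeSize"] = 1
        state["maxEdgeSize"] = 1
    return state


upgraded = migrate_settings({"graphMinDegree": 0}, 9)
# -> {"graphMaxNodes": 1000, "minEdgeSize": 1, "maxEdgeSize": 1}
```

Because the checks are cumulative rather than exclusive, a state saved at version 9 passes through both blocks, while one saved at version 10 only receives the edge-size defaults.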
@@ -22,10 +22,13 @@ interface AuthState {
|
||||
coreVersion: string | null;
|
||||
apiVersion: string | null;
|
||||
username: string | null; // login username
|
||||
webuiTitle: string | null; // Custom title
|
||||
webuiDescription: string | null; // Title description
|
||||
|
||||
login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null) => void;
|
||||
login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null, webuiTitle?: string | null, webuiDescription?: string | null) => void;
|
||||
logout: () => void;
|
||||
setVersion: (coreVersion: string | null, apiVersion: string | null) => void;
|
||||
setCustomTitle: (webuiTitle: string | null, webuiDescription: string | null) => void;
|
||||
}
|
||||
|
||||
const useBackendStateStoreBase = create<BackendState>()((set) => ({
|
||||
@@ -47,6 +50,14 @@ const useBackendStateStoreBase = create<BackendState>()((set) => ({
|
||||
);
|
||||
}
|
||||
|
||||
// Update custom title information if health check returns it
|
||||
if ('webui_title' in health || 'webui_description' in health) {
|
||||
useAuthStore.getState().setCustomTitle(
|
||||
'webui_title' in health ? (health.webui_title ?? null) : null,
|
||||
'webui_description' in health ? (health.webui_description ?? null) : null
|
||||
);
|
||||
}
|
||||
|
||||
set({
|
||||
health: true,
|
||||
message: null,
|
||||
@@ -107,10 +118,12 @@ const isGuestToken = (token: string): boolean => {
|
||||
return payload.role === 'guest';
|
||||
};
|
||||
|
||||
const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null } => {
|
||||
const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null } => {
|
||||
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
|
||||
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
|
||||
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
const username = token ? getUsernameFromToken(token) : null;
|
||||
|
||||
if (!token) {
|
||||
@@ -120,6 +133,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
username: null,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -129,6 +144,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
username: username,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -142,8 +159,10 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
coreVersion: initialState.coreVersion,
|
||||
apiVersion: initialState.apiVersion,
|
||||
username: initialState.username,
|
||||
webuiTitle: initialState.webuiTitle,
|
||||
webuiDescription: initialState.webuiDescription,
|
||||
|
||||
login: (token, isGuest = false, coreVersion = null, apiVersion = null) => {
|
||||
login: (token, isGuest = false, coreVersion = null, apiVersion = null, webuiTitle = null, webuiDescription = null) => {
|
||||
localStorage.setItem('LIGHTRAG-API-TOKEN', token);
|
||||
|
||||
if (coreVersion) {
|
||||
@@ -153,6 +172,18 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
localStorage.setItem('LIGHTRAG-API-VERSION', apiVersion);
|
||||
}
|
||||
|
||||
if (webuiTitle) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle);
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-TITLE');
|
||||
}
|
||||
|
||||
if (webuiDescription) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription);
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
}
|
||||
|
||||
const username = getUsernameFromToken(token);
|
||||
set({
|
||||
isAuthenticated: true,
|
||||
@@ -160,6 +191,8 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
username: username,
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
});
|
||||
},
|
||||
|
||||
@@ -168,6 +201,8 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
|
||||
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
|
||||
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
|
||||
set({
|
||||
isAuthenticated: false,
|
||||
@@ -175,6 +210,8 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
username: null,
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
});
|
||||
},
|
||||
|
||||
@@ -192,6 +229,27 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion
|
||||
});
|
||||
},
|
||||
|
||||
setCustomTitle: (webuiTitle, webuiDescription) => {
|
||||
// Update localStorage
|
||||
if (webuiTitle) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle);
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-TITLE');
|
||||
}
|
||||
|
||||
if (webuiDescription) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription);
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
}
|
||||
|
||||
// Update state
|
||||
set({
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription
|
||||
});
|
||||
}
|
||||
};
|
||||
});
|
||||
|
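Both `login` and `setCustomTitle` above follow the same rule for the optional title fields: a non-empty value is written to `localStorage`, and an empty or missing value removes the key. A minimal Python sketch of that rule, with a dict standing in for `localStorage`; the helper name `persist_optional` is hypothetical:

```python
def persist_optional(storage, key, value):
    """Store value under key when it is truthy, otherwise remove the key."""
    if value:
        storage[key] = value
    else:
        # Removing a key that is absent is a no-op, as with localStorage.removeItem.
        storage.pop(key, None)


storage = {}
persist_optional(storage, "LIGHTRAG-WEBUI-TITLE", "My Graph")
persist_optional(storage, "LIGHTRAG-WEBUI-DESCRIPTION", None)
# storage -> {"LIGHTRAG-WEBUI-TITLE": "My Graph"}
```

Note that an empty string is treated like a missing value here, which matches the truthiness check in the TypeScript code.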
440
tests/test_graph_storage.py
Normal file
@@ -0,0 +1,440 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
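General-purpose graph storage test program
|
||||
|
||||
This program selects the graph storage type to use based on the LIGHTRAG_GRAPH_STORAGE setting in .env,
|
||||
and runs basic and advanced operation tests against it.
|
||||
|
||||
Supported graph storage types include: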
|
||||
- NetworkXStorage
|
||||
- Neo4JStorage
|
||||
- PGGraphStorage
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import importlib
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from ascii_colors import ASCIIColors
|
||||
|
||||
# Add the project root directory to the Python path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from lightrag.kg import (
|
||||
STORAGE_IMPLEMENTATIONS,
|
||||
STORAGE_ENV_REQUIREMENTS,
|
||||
STORAGES,
|
||||
verify_storage_implementation,
|
||||
)
|
||||
from lightrag.kg.shared_storage import initialize_share_data
|
||||
|
||||
|
||||
# Mock embedding function that returns random vectors
|
||||
async def mock_embedding_func(texts):
|
||||
return np.random.rand(len(texts), 10) # Return 10-dimensional random vectors
|
||||
|
||||
|
||||
def check_env_file():
|
||||
"""
|
||||
Check whether the .env file exists and emit a warning if it does not.
|
||||
Return True if execution should continue, False if it should exit.
|
||||
"""
|
||||
if not os.path.exists(".env"):
|
||||
warning_msg = "警告: 当前目录中没有找到.env文件,这可能会影响存储配置的加载。"
|
||||
ASCIIColors.yellow(warning_msg)
|
||||
|
||||
# Check whether we are running in an interactive terminal
|
||||
if sys.stdin.isatty():
|
||||
response = input("是否继续执行? (yes/no): ")
|
||||
if response.lower() != "yes":
|
||||
ASCIIColors.red("测试程序已取消")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def initialize_graph_storage():
|
||||
"""
|
||||
Initialize the graph storage instance selected by the environment variables.
|
||||
Return the initialized storage instance.
|
||||
"""
|
||||
# Get the graph storage type from the environment variables
|
||||
graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
|
||||
|
||||
# Verify that the storage type is valid
|
||||
try:
|
||||
verify_storage_implementation("GRAPH_STORAGE", graph_storage_type)
|
||||
except ValueError as e:
|
||||
ASCIIColors.red(f"错误: {str(e)}")
|
||||
ASCIIColors.yellow(
|
||||
f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Check the required environment variables
|
||||
required_env_vars = STORAGE_ENV_REQUIREMENTS.get(graph_storage_type, [])
|
||||
missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
|
||||
|
||||
if missing_env_vars:
|
||||
ASCIIColors.red(
|
||||
f"错误: {graph_storage_type} 需要以下环境变量,但未设置: {', '.join(missing_env_vars)}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Dynamically import the corresponding module
|
||||
module_path = STORAGES.get(graph_storage_type)
|
||||
if not module_path:
|
||||
ASCIIColors.red(f"错误: 未找到 {graph_storage_type} 的模块路径")
|
||||
return None
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_path, package="lightrag")
|
||||
storage_class = getattr(module, graph_storage_type)
|
||||
except (ImportError, AttributeError) as e:
|
||||
ASCIIColors.red(f"错误: 导入 {graph_storage_type} 失败: {str(e)}")
|
||||
return None
|
||||
|
||||
# Initialize the storage instance
|
||||
global_config = {
|
||||
"embedding_batch_num": 10, # Batch size
|
||||
"vector_db_storage_cls_kwargs": {
|
||||
"cosine_better_than_threshold": 0.5 # Cosine similarity threshold
|
||||
},
|
||||
"working_dir": os.environ.get("WORKING_DIR", "./rag_storage"), # Working directory
|
||||
}
|
||||
|
||||
# NetworkXStorage requires shared_storage to be initialized first
|
||||
if graph_storage_type == "NetworkXStorage":
|
||||
initialize_share_data() # Use single-process mode
|
||||
|
||||
try:
|
||||
storage = storage_class(
|
||||
namespace="test_graph",
|
||||
global_config=global_config,
|
||||
embedding_func=mock_embedding_func,
|
||||
)
|
||||
|
||||
# Initialize the connection
|
||||
await storage.initialize()
|
||||
return storage
|
||||
except Exception as e:
|
||||
ASCIIColors.red(f"错误: 初始化 {graph_storage_type} 失败: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
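`initialize_graph_storage` above resolves the storage class at runtime: it imports a module by path with `importlib.import_module` and then looks the class up by name with `getattr`, treating `ImportError` and `AttributeError` as "not available". A self-contained sketch of that pattern against the standard library; the helper name `load_class` is an assumption for illustration, and `collections.OrderedDict` merely stands in for a storage class:

```python
import importlib


def load_class(module_path, class_name, package=None):
    """Import module_path and return the attribute class_name, or None on failure."""
    try:
        module = importlib.import_module(module_path, package=package)
        return getattr(module, class_name)
    except (ImportError, AttributeError):
        # A missing module or a missing class is reported the same way.
        return None


ordered_dict_cls = load_class("collections", "OrderedDict")
missing_cls = load_class("collections", "NoSuchClass")  # -> None
```

The `package` argument only matters for relative module paths (those starting with a dot), which is why the test program passes `package="lightrag"`.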
async def test_graph_basic(storage):
|
||||
"""
|
||||
Test the basic operations of the graph database:
|
||||
1. Insert two nodes with upsert_node
|
||||
2. Insert an edge connecting the two nodes with upsert_edge
|
||||
3. Read a node with get_node
|
||||
4. Read an edge with get_edge
|
||||
"""
|
||||
try:
|
||||
# Clean up data left over from previous test runs
|
||||
print("清理之前的测试数据...")
|
||||
await storage.drop()
|
||||
|
||||
# 1. Insert the first node
|
||||
node1_id = "人工智能"
|
||||
node1_data = {
|
||||
"entity_id": node1_id,
|
||||
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
|
||||
"keywords": "AI,机器学习,深度学习",
|
||||
"entity_type": "技术领域",
|
||||
}
|
||||
print(f"插入节点1: {node1_id}")
|
||||
await storage.upsert_node(node1_id, node1_data)
|
||||
|
||||
# 2. Insert the second node
|
||||
node2_id = "机器学习"
|
||||
node2_data = {
|
||||
"entity_id": node2_id,
|
||||
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
|
||||
"keywords": "监督学习,无监督学习,强化学习",
|
||||
"entity_type": "技术领域",
|
||||
}
|
||||
print(f"插入节点2: {node2_id}")
|
||||
await storage.upsert_node(node2_id, node2_data)
|
||||
|
||||
# 3. Insert the connecting edge
|
||||
edge_data = {
|
||||
"relationship": "包含",
|
||||
"weight": 1.0,
|
||||
"description": "人工智能领域包含机器学习这个子领域",
|
||||
}
|
||||
print(f"插入边: {node1_id} -> {node2_id}")
|
||||
await storage.upsert_edge(node1_id, node2_id, edge_data)
|
||||
|
||||
# 4. Read the node properties
|
||||
print(f"读取节点属性: {node1_id}")
|
||||
node1_props = await storage.get_node(node1_id)
|
||||
if node1_props:
|
||||
print(f"成功读取节点属性: {node1_id}")
|
||||
print(f"节点描述: {node1_props.get('description', '无描述')}")
|
||||
print(f"节点类型: {node1_props.get('entity_type', '无类型')}")
|
||||
print(f"节点关键词: {node1_props.get('keywords', '无关键词')}")
|
||||
# Verify that the returned properties are correct
|
||||
assert (
|
||||
node1_props.get("entity_id") == node1_id
|
||||
), f"节点ID不匹配: 期望 {node1_id}, 实际 {node1_props.get('entity_id')}"
|
||||
assert (
|
||||
node1_props.get("description") == node1_data["description"]
|
||||
), "节点描述不匹配"
|
||||
assert (
|
||||
node1_props.get("entity_type") == node1_data["entity_type"]
|
||||
), "节点类型不匹配"
|
||||
else:
|
||||
print(f"读取节点属性失败: {node1_id}")
|
||||
assert False, f"未能读取节点属性: {node1_id}"
|
||||
|
||||
# 5. Read the edge properties
|
||||
print(f"读取边属性: {node1_id} -> {node2_id}")
|
||||
edge_props = await storage.get_edge(node1_id, node2_id)
|
||||
if edge_props:
|
||||
print(f"成功读取边属性: {node1_id} -> {node2_id}")
|
||||
print(f"边关系: {edge_props.get('relationship', '无关系')}")
|
||||
print(f"边描述: {edge_props.get('description', '无描述')}")
|
||||
print(f"边权重: {edge_props.get('weight', '无权重')}")
|
||||
# Verify that the returned properties are correct
|
||||
assert (
|
||||
edge_props.get("relationship") == edge_data["relationship"]
|
||||
), "边关系不匹配"
|
||||
assert (
|
||||
edge_props.get("description") == edge_data["description"]
|
||||
), "边描述不匹配"
|
||||
assert edge_props.get("weight") == edge_data["weight"], "边权重不匹配"
|
||||
else:
|
||||
print(f"读取边属性失败: {node1_id} -> {node2_id}")
|
||||
assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
|
||||
|
||||
print("基本测试完成,数据已保留在数据库中")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
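The basic test above is a write-then-read round trip: upsert two nodes and one edge, read them back, and compare the returned properties with what was written. The same pattern can be tried without any database by using a tiny in-memory stand-in. The `InMemoryGraph` class below is hypothetical and exists only for this sketch; it is not one of the LightRAG storage backends, and it stores edges by ordered `(src, tgt)` pair as a simplifying assumption:

```python
import asyncio


class InMemoryGraph:
    """Hypothetical stand-in for a graph storage backend (illustration only)."""

    def __init__(self):
        self._nodes = {}
        self._edges = {}

    async def upsert_node(self, node_id, data):
        self._nodes[node_id] = dict(data)

    async def upsert_edge(self, src, tgt, data):
        self._edges[(src, tgt)] = dict(data)

    async def get_node(self, node_id):
        return self._nodes.get(node_id)

    async def get_edge(self, src, tgt):
        return self._edges.get((src, tgt))


async def round_trip():
    storage = InMemoryGraph()
    await storage.upsert_node("A", {"entity_id": "A", "entity_type": "topic"})
    await storage.upsert_node("B", {"entity_id": "B", "entity_type": "topic"})
    await storage.upsert_edge("A", "B", {"relationship": "contains", "weight": 1.0})
    # Read back what was written and hand it to the caller for checking.
    return await storage.get_node("A"), await storage.get_edge("A", "B")


node, edge = asyncio.run(round_trip())
```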
async def test_graph_advanced(storage):
|
||||
"""
|
||||
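Test the advanced operations of the graph database:
|
||||
1. Get the degree of a node with node_degree
|
||||
2. Get the degree of an edge with edge_degree
|
||||
3. Get all edges of a node with get_node_edges
|
||||
4. Get all labels with get_all_labels
|
||||
5. Get the knowledge graph with get_knowledge_graph
|
||||
6. Delete a node with delete_node
|
||||
7. Delete nodes in bulk with remove_nodes
|
||||
8. Delete edges with remove_edges
|
||||
9. Clean up the data with drop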
|
||||
"""
|
||||
try:
|
||||
# Clean up data left over from previous test runs
|
||||
print("清理之前的测试数据...\n")
|
||||
await storage.drop()
|
||||
|
||||
# 1. Insert the test data
|
||||
# Insert node 1: 人工智能 (artificial intelligence)
|
||||
node1_id = "人工智能"
|
||||
node1_data = {
|
||||
"entity_id": node1_id,
|
||||
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
|
||||
"keywords": "AI,机器学习,深度学习",
|
||||
"entity_type": "技术领域",
|
||||
}
|
||||
print(f"插入节点1: {node1_id}")
|
||||
await storage.upsert_node(node1_id, node1_data)
|
||||
|
||||
# 插入节点2: 机器学习
|
||||
node2_id = "机器学习"
|
||||
node2_data = {
|
||||
"entity_id": node2_id,
|
||||
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
|
||||
"keywords": "监督学习,无监督学习,强化学习",
|
||||
"entity_type": "技术领域",
|
||||
}
|
||||
print(f"插入节点2: {node2_id}")
|
||||
await storage.upsert_node(node2_id, node2_data)
|
||||
|
||||
# 插入节点3: 深度学习
|
||||
node3_id = "深度学习"
|
||||
node3_data = {
|
||||
"entity_id": node3_id,
|
||||
"description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
|
||||
"keywords": "神经网络,CNN,RNN",
|
||||
"entity_type": "技术领域",
|
||||
}
|
||||
print(f"插入节点3: {node3_id}")
|
||||
await storage.upsert_node(node3_id, node3_data)
|
||||
|
||||
        # Insert edge 1: node 1 (artificial intelligence) -> node 2 (machine learning)
        edge1_data = {
            "relationship": "contains",
            "weight": 1.0,
            "description": "The field of artificial intelligence contains machine learning as a subfield",
        }
        print(f"Inserting edge 1: {node1_id} -> {node2_id}")
        await storage.upsert_edge(node1_id, node2_id, edge1_data)

        # Insert edge 2: node 2 (machine learning) -> node 3 (deep learning)
        edge2_data = {
            "relationship": "contains",
            "weight": 1.0,
            "description": "The field of machine learning contains deep learning as a subfield",
        }
        print(f"Inserting edge 2: {node2_id} -> {node3_id}")
        await storage.upsert_edge(node2_id, node3_id, edge2_data)

        # 2. Test node_degree: get the degree of a node
        print(f"== Testing node_degree: {node1_id}")
        node1_degree = await storage.node_degree(node1_id)
        print(f"Degree of node {node1_id}: {node1_degree}")
        assert (
            node1_degree == 1
        ), f"Degree of node {node1_id} should be 1, got {node1_degree}"

        # 3. Test edge_degree: get the degree of an edge
        print(f"== Testing edge_degree: {node1_id} -> {node2_id}")
        edge_degree = await storage.edge_degree(node1_id, node2_id)
        print(f"Degree of edge {node1_id} -> {node2_id}: {edge_degree}")
        # Expected value is 3: the degree of node 1 (1) plus the degree of node 2 (2)
        assert (
            edge_degree == 3
        ), f"Degree of edge {node1_id} -> {node2_id} should be 3, got {edge_degree}"

        # 4. Test get_node_edges: get all edges of a node
        print(f"== Testing get_node_edges: {node2_id}")
        node2_edges = await storage.get_node_edges(node2_id)
        print(f"All edges of node {node2_id}: {node2_edges}")
        assert (
            len(node2_edges) == 2
        ), f"Node {node2_id} should have 2 edges, got {len(node2_edges)}"

        # 5. Test get_all_labels: get all labels
        print("== Testing get_all_labels")
        all_labels = await storage.get_all_labels()
        print(f"All labels: {all_labels}")
        assert len(all_labels) == 3, f"Expected 3 labels, got {len(all_labels)}"
        assert node1_id in all_labels, f"{node1_id} should be in the label list"
        assert node2_id in all_labels, f"{node2_id} should be in the label list"
        assert node3_id in all_labels, f"{node3_id} should be in the label list"

        # 6. Test get_knowledge_graph: get the knowledge graph
        print("== Testing get_knowledge_graph")
        kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
        print(f"Knowledge graph node count: {len(kg.nodes)}")
        print(f"Knowledge graph edge count: {len(kg.edges)}")
        assert isinstance(kg, KnowledgeGraph), "Result should be a KnowledgeGraph"
        assert (
            len(kg.nodes) == 3
        ), f"Knowledge graph should have 3 nodes, got {len(kg.nodes)}"
        assert (
            len(kg.edges) == 2
        ), f"Knowledge graph should have 2 edges, got {len(kg.edges)}"

        # 7. Test delete_node: delete a node
        print(f"== Testing delete_node: {node3_id}")
        await storage.delete_node(node3_id)
        node3_props = await storage.get_node(node3_id)
        print(f"Node properties of {node3_id} after deletion: {node3_props}")
        assert node3_props is None, f"Node {node3_id} should have been deleted"

        # Re-insert node 3 and its edge for the tests that follow
        await storage.upsert_node(node3_id, node3_data)
        await storage.upsert_edge(node2_id, node3_id, edge2_data)

        # 8. Test remove_edges: delete edges
        print(f"== Testing remove_edges: {node2_id} -> {node3_id}")
        await storage.remove_edges([(node2_id, node3_id)])
        edge_props = await storage.get_edge(node2_id, node3_id)
        print(
            f"Edge properties of {node2_id} -> {node3_id} after deletion: {edge_props}"
        )
        assert (
            edge_props is None
        ), f"Edge {node2_id} -> {node3_id} should have been deleted"

        # 9. Test remove_nodes: delete nodes in bulk
        print(f"== Testing remove_nodes: [{node2_id}, {node3_id}]")
        await storage.remove_nodes([node2_id, node3_id])
        node2_props = await storage.get_node(node2_id)
        node3_props = await storage.get_node(node3_id)
        print(f"Node properties of {node2_id} after deletion: {node2_props}")
        print(f"Node properties of {node3_id} after deletion: {node3_props}")
        assert node2_props is None, f"Node {node2_id} should have been deleted"
        assert node3_props is None, f"Node {node3_id} should have been deleted"

        # 10. Test drop: clean up the data
        print("== Testing drop")
        result = await storage.drop()
        print(f"Cleanup result: {result}")
        assert (
            result["status"] == "success"
        ), f"Cleanup should succeed, but the status was {result['status']}"

        # Verify the cleanup
        all_labels = await storage.get_all_labels()
        print(f"All labels after cleanup: {all_labels}")
        assert (
            len(all_labels) == 0
        ), f"Expected no labels after cleanup, got {len(all_labels)}"

        print("\nAdvanced test finished")
        return True

    except Exception as e:
        ASCIIColors.red(f"An error occurred during the test: {str(e)}")
        return False


async def main():
    """Main entry point."""
    # Show the program banner
    ASCIIColors.cyan("""
    ╔══════════════════════════════════════════════════════════════╗
    ║              Generic Graph Storage Test Program              ║
    ╚══════════════════════════════════════════════════════════════╝
    """)

    # Check for the .env file
    if not check_env_file():
        return

    # Load environment variables
    load_dotenv(dotenv_path=".env", override=False)

    # Get the graph storage type
    graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
    ASCIIColors.magenta(f"\nConfigured graph storage type: {graph_storage_type}")
    ASCIIColors.white(
        f"Supported graph storage types: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
    )

    # Initialize the storage instance
    storage = await initialize_graph_storage()
    if not storage:
        ASCIIColors.red("Failed to initialize the storage instance; exiting")
        return

    try:
        # Show the test options
        ASCIIColors.yellow("\nSelect a test type:")
        ASCIIColors.white("1. Basic test (insert and read nodes and edges)")
        ASCIIColors.white(
            "2. Advanced test (degrees, labels, knowledge graph, deletions, etc.)"
        )
        ASCIIColors.white("3. All tests")

        choice = input("\nEnter an option (1/2/3): ")

        if choice == "1":
            await test_graph_basic(storage)
        elif choice == "2":
            await test_graph_advanced(storage)
        elif choice == "3":
            ASCIIColors.cyan("\n=== Starting basic test ===")
            basic_result = await test_graph_basic(storage)

            if basic_result:
                ASCIIColors.cyan("\n=== Starting advanced test ===")
                await test_graph_advanced(storage)
        else:
            ASCIIColors.red("Invalid option")

    finally:
        # Close the connection
        if storage:
            await storage.finalize()
            ASCIIColors.green("\nStorage connection closed")


if __name__ == "__main__":
    asyncio.run(main())
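
The degree values asserted in the advanced test (1 for the first node, 3 for the first edge, 2 edges on the middle node) follow from the three-node chain the test builds. The sketch below reproduces that arithmetic with a plain in-memory adjacency map. It is not the storage implementation: the helper names `build_adjacency`, `node_degree`, and `edge_degree` and the short node names are hypothetical, and it assumes an undirected graph in which an edge's degree is the sum of its two endpoint degrees, which is consistent with the value 3 the test expects.

```python
from collections import defaultdict


def build_adjacency(edges):
    """Build an undirected adjacency map from (source, target) pairs."""
    adjacency = defaultdict(set)
    for src, tgt in edges:
        adjacency[src].add(tgt)
        adjacency[tgt].add(src)
    return adjacency


def node_degree(adjacency, node):
    """Number of edges touching `node` (0 for an unknown node)."""
    return len(adjacency.get(node, ()))


def edge_degree(adjacency, src, tgt):
    """Assumed definition: sum of the degrees of the two endpoints."""
    return node_degree(adjacency, src) + node_degree(adjacency, tgt)


# Same shape as the test graph: a chain of three nodes, AI -> ML -> DL.
graph = build_adjacency([("AI", "ML"), ("ML", "DL")])

print(node_degree(graph, "AI"))        # 1: only the AI-ML edge
print(edge_degree(graph, "AI", "ML"))  # 3: degree of AI (1) + degree of ML (2)
print(node_degree(graph, "ML"))        # 2: the AI-ML and ML-DL edges
```

A backend that counts degrees differently (for example, only outgoing edges in a directed graph) would produce other numbers, so the expected values in the test are tied to this undirected interpretation.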