Merge branch 'HKUDS:main' into main

jin committed 2025-01-16 09:59:27 +08:00 (committed by GitHub)
19 changed files with 2609 additions and 2396 deletions

README.md

@@ -12,7 +12,7 @@
 </p>
 <p>
     <img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
-    <img src="https://img.shields.io/badge/python->=3.10-blue">
+    <img src="https://img.shields.io/badge/python-3.10-blue">
     <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
     <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
 </p>
@@ -26,7 +26,8 @@ This repository hosts the code of LightRAG. The structure of this code is based
 </div>

 ## 🎉 News
-- [x] [2025.01.06]🎯📢LightRAG now supports [PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage).
+- [x] [2025.01.13]🎯📢Our team has released [MiniRAG](https://github.com/HKUDS/MiniRAG), making RAG simpler with small models.
+- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage).
 - [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
 - [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
 - [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
@@ -361,6 +362,18 @@ see test_neo4j.py for a working example.

 For production-level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (Apache AGE).
 * PostgreSQL is lightweight; the whole binary distribution, including all necessary plugins, can be zipped to 40 MB. See the [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0); it is also easy to install on Linux/Mac.
 * How to start? See [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
+* Create an index for AGE, for example (change `dickens` below to your graph name if necessary):
+  ```sql
+  SET search_path = ag_catalog, "$user", public;
+  CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
+  ```
+* Known issue of Apache AGE: the released versions have the following problem:
+  > You might find that the properties of the nodes/edges are empty.
+  > It is a known issue of the release version: https://github.com/apache/age/pull/1721
+  >
+  > You can compile AGE from source code to fix it (a quick check follows below).
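To check whether your AGE build is affected, a quick probe like the following can help (a minimal sketch, assuming the default `dickens` graph):

```sql
-- Minimal sanity check: inspect a few node property maps.
SET search_path = ag_catalog, "$user", public;
-- Empty maps ({}) here indicate the empty-properties bug described above.
SELECT * FROM cypher('dickens', $$ MATCH (n) RETURN properties(n) LIMIT 3 $$) AS (props agtype);
```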
### Insert Custom KG
@@ -912,12 +925,14 @@ pip install -e ".[api]"

 ### Prerequisites

-Before running any of the servers, ensure you have the corresponding backend service running:
+Before running any of the servers, ensure you have the corresponding backend services running for both the LLM and the embedding.
+The new API allows you to mix different bindings for the LLM and the embedding; for example, you can use Ollama for the embedding and OpenAI for the LLM, as the sketch below shows.
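A minimal sketch of such a mixed setup (the model names here are illustrative assumptions, not defaults):

```bash
# openai for the LLM, ollama for the embedding
lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding ollama --embedding-model bge-m3:latest
```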
 #### For LoLLMs Server
 - LoLLMs must be running and accessible
 - Default connection: http://localhost:9600
-- Configure using --lollms-host if running on a different host/port
+- Configure using --llm-binding-host and/or --embedding-binding-host if running on a different host/port

 #### For Ollama Server
 - Ollama must be running and accessible
@@ -953,113 +968,96 @@ The output of the last command will give you the endpoint and the key for the Op
 Each server has its own specific configuration options:

-#### LoLLMs Server Options
-
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| --host | 0.0.0.0 | RAG server host |
-| --port | 9621 | RAG server port |
-| --model | mistral-nemo:latest | LLM model name |
-| --embedding-model | bge-m3:latest | Embedding model name |
-| --lollms-host | http://localhost:9600 | LoLLMS backend URL |
-| --working-dir | ./rag_storage | Working directory for RAG |
-| --max-async | 4 | Maximum async operations |
-| --max-tokens | 32768 | Maximum token size |
-| --embedding-dim | 1024 | Embedding dimensions |
-| --max-embed-tokens | 8192 | Maximum embedding token size |
-| --input-file | ./book.txt | Initial input file |
-| --log-level | INFO | Logging level |
-| --key | none | Access Key to protect the lightrag service |
-
-#### Ollama Server Options
-
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| --host | 0.0.0.0 | RAG server host |
-| --port | 9621 | RAG server port |
-| --model | mistral-nemo:latest | LLM model name |
-| --embedding-model | bge-m3:latest | Embedding model name |
-| --ollama-host | http://localhost:11434 | Ollama backend URL |
-| --working-dir | ./rag_storage | Working directory for RAG |
-| --max-async | 4 | Maximum async operations |
-| --max-tokens | 32768 | Maximum token size |
-| --embedding-dim | 1024 | Embedding dimensions |
-| --max-embed-tokens | 8192 | Maximum embedding token size |
-| --input-file | ./book.txt | Initial input file |
-| --log-level | INFO | Logging level |
-| --key | none | Access Key to protect the lightrag service |
-
-#### OpenAI Server Options
-
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| --host | 0.0.0.0 | RAG server host |
-| --port | 9621 | RAG server port |
-| --model | gpt-4 | OpenAI model name |
-| --embedding-model | text-embedding-3-large | OpenAI embedding model |
-| --working-dir | ./rag_storage | Working directory for RAG |
-| --max-tokens | 32768 | Maximum token size |
-| --max-embed-tokens | 8192 | Maximum embedding token size |
-| --input-dir | ./inputs | Input directory for documents |
-| --log-level | INFO | Logging level |
-| --key | none | Access Key to protect the lightrag service |
-
-#### OpenAI AZURE Server Options
+#### LightRAG Server Options

 | Parameter | Default | Description |
 |-----------|---------|-------------|
 | --host | 0.0.0.0 | Server host |
 | --port | 9621 | Server port |
-| --model | gpt-4 | OpenAI model name |
-| --embedding-model | text-embedding-3-large | OpenAI embedding model |
-| --working-dir | ./rag_storage | Working directory for RAG |
-| --max-tokens | 32768 | Maximum token size |
-| --max-embed-tokens | 8192 | Maximum embedding token size |
-| --input-dir | ./inputs | Input directory for documents |
-| --enable-cache | True | Enable response cache |
-| --log-level | INFO | Logging level |
-| --key | none | Access Key to protect the lightrag service |
+| --llm-binding | ollama | LLM binding to be used. Supported: lollms, ollama, openai, azure_openai |
+| --llm-binding-host | (dynamic) | LLM server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
+| --llm-model | mistral-nemo:latest | LLM model name |
+| --embedding-binding | ollama | Embedding binding to be used. Supported: lollms, ollama, openai, azure_openai |
+| --embedding-binding-host | (dynamic) | Embedding server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
+| --embedding-model | bge-m3:latest | Embedding model name |
+| --working-dir | ./rag_storage | Working directory for RAG storage |
+| --input-dir | ./inputs | Directory containing input documents |
+| --max-async | 4 | Maximum async operations |
+| --max-tokens | 32768 | Maximum token size |
+| --embedding-dim | 1024 | Embedding dimensions |
+| --max-embed-tokens | 8192 | Maximum embedding token size |
+| --timeout | None | Timeout in seconds (useful when using slow AI). Use None for infinite timeout |
+| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
+| --key | None | API key for authentication. Protects the lightrag server against unauthorized access |
+| --ssl | False | Enable HTTPS |
+| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
+| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
 For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`.

 ### Example Usage

-#### LoLLMs RAG Server
-
-```bash
-# Custom configuration with specific model and working directory
-lollms-lightrag-server --model mistral-nemo --port 8080 --working-dir ./custom_rag
-
-# Using specific models (ensure they are installed in your LoLLMs instance)
-lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024
-
-# Using specific models and an authentication key
-lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 --key ky-mykey
-```
-
-#### Ollama RAG Server
-
-```bash
-# Custom configuration with specific model and working directory
-ollama-lightrag-server --model mistral-nemo:latest --port 8080 --working-dir ./custom_rag
-
-# Using specific models (ensure they are installed in your Ollama instance)
-ollama-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024
-```
-
-#### OpenAI RAG Server
-
-```bash
-# Using GPT-4 with text-embedding-3-large
-openai-lightrag-server --port 9624 --model gpt-4 --embedding-model text-embedding-3-large
-```
-
-#### Azure OpenAI RAG Server
-
-```bash
-# Using GPT-4 with text-embedding-3-large
-azure-openai-lightrag-server --model gpt-4o --port 8080 --working-dir ./custom_rag --embedding-model text-embedding-3-large
-```
+#### Running a LightRAG server with the default local Ollama server as the LLM and embedding backend
+
+Ollama is the default backend for both the LLM and the embedding, so by default you can run lightrag-server with no parameters and the defaults will be used. Make sure Ollama is installed and running, and that the default models are already pulled into your Ollama instance.
+
+```bash
+# Run lightrag with ollama, mistral-nemo:latest for llm, and bge-m3:latest for embedding
+lightrag-server
+
+# Using specific models (ensure they are installed in your ollama instance)
+lightrag-server --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-model nomic-embed-text --embedding-dim 1024
+
+# Using an authentication key
+lightrag-server --key my-key
+
+# Using lollms for llm and ollama for embedding
+lightrag-server --llm-binding lollms
+```
+
+#### Running a LightRAG server with the default local LoLLMs server as the LLM and embedding backend
+
+```bash
+# Run lightrag with lollms, mistral-nemo:latest for llm, and bge-m3:latest for embedding; use lollms for both llm and embedding
+lightrag-server --llm-binding lollms --embedding-binding lollms
+
+# Using specific models (ensure they are installed in your lollms instance)
+lightrag-server --llm-binding lollms --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-binding lollms --embedding-model nomic-embed-text --embedding-dim 1024
+
+# Using an authentication key
+lightrag-server --llm-binding lollms --embedding-binding lollms --key my-key
+
+# Using lollms for llm and openai for embedding
+lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
+```
+
+#### Running a LightRAG server with an OpenAI server as the LLM and embedding backend
+
+```bash
+# Run lightrag with openai, gpt-4o-mini for llm, and text-embedding-3-small for embedding; use openai for both llm and embedding
+lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small
+
+# Using an authentication key
+lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small --key my-key
+
+# Using lollms for llm and openai for embedding
+lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
+```
+
+#### Running a LightRAG server with an Azure OpenAI server as the LLM and embedding backend
+
+```bash
+# Run lightrag with azure_openai, gpt-4o-mini for llm, and text-embedding-3-small for embedding
+lightrag-server --llm-binding azure_openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small
+
+# Using an authentication key
+lightrag-server --llm-binding azure_openai --llm-model gpt-4o-mini --embedding-binding azure_openai --embedding-model text-embedding-3-small --key my-key
+
+# Using lollms for llm and azure_openai for embedding
+lightrag-server --llm-binding lollms --embedding-binding azure_openai --embedding-model text-embedding-3-small
+```
 **Important Notes:**
 - For LoLLMs: Make sure the specified models are installed in your LoLLMs instance

@@ -1069,10 +1067,7 @@ azure-openai-lightrag-server --model gpt-4o --port 8080 --working-dir ./custom_r
 For help on any server, use the --help flag:

 ```bash
-lollms-lightrag-server --help
-ollama-lightrag-server --help
-openai-lightrag-server --help
-azure-openai-lightrag-server --help
+lightrag-server --help
 ```

 Note: If you don't need the API functionality, you can install the base package without API support using:
@@ -1092,7 +1087,7 @@ Query the RAG system with options for different search modes.

 ```bash
 curl -X POST "http://localhost:9621/query" \
     -H "Content-Type: application/json" \
     -d '{"query": "Your question here", "mode": "hybrid"}'
 ```

 #### POST /query/stream
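For reference, the streaming endpoint accepts the same request body (a sketch, assuming the default port; responses arrive as newline-delimited JSON chunks):

```bash
curl -X POST "http://localhost:9621/query/stream" \
    -H "Content-Type: application/json" \
    -d '{"query": "Your question here", "mode": "hybrid"}'
```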

File diff suppressed because it is too large

lightrag/__init__.py

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.1.0"
+__version__ = "1.1.1"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

lightrag/api/azure_openai_lightrag_server.py (deleted)

@@ -1,532 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import asyncio
import logging
import argparse
from lightrag import LightRAG, QueryParam
from lightrag.llm import (
azure_openai_complete_if_cache,
azure_openai_embedding,
)
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import os
from dotenv import load_dotenv
import inspect
import json
from fastapi.responses import StreamingResponse
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
load_dotenv()
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
def parse_args():
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with OpenAI integration"
)
# Server configuration
parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration
parser.add_argument(
"--working-dir",
default="./rag_storage",
help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# Model configuration
parser.add_argument(
"--model", default="gpt-4o", help="OpenAI model name (default: gpt-4o)"
)
parser.add_argument(
"--embedding-model",
default="text-embedding-3-large",
help="OpenAI embedding model (default: text-embedding-3-large)",
)
# RAG configuration
parser.add_argument(
"--max-tokens",
type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
parser.add_argument(
"--enable-cache",
default=True,
help="Enable response cache (default: True)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
class DocumentManager:
"""Handles document operations and tracking"""
def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions
self.indexed_files = set()
# Create input directory if it doesn't exist
self.input_dir.mkdir(parents=True, exist_ok=True)
def scan_directory(self) -> List[Path]:
"""Scan input directory for new files"""
new_files = []
for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
def mark_as_indexed(self, file_path: Path):
"""Mark a file as indexed"""
self.indexed_files.add(file_path)
def is_supported_file(self, filename: str) -> bool:
"""Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
only_need_context: bool = False
# stream: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
status: str
message: str
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."]
embedding = await azure_openai_embedding(test_text, model=embedding_model)
return embedding.shape[1]
def create_app(args):
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
# Get embedding dimensions
embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model))
async def async_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
):
"""Async wrapper for OpenAI completion"""
kwargs.pop("keyword_extraction", None)
return await azure_openai_complete_if_cache(
args.model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=AZURE_OPENAI_ENDPOINT,
api_key=AZURE_OPENAI_API_KEY,
api_version=AZURE_OPENAI_API_VERSION,
**kwargs,
)
# Initialize RAG with OpenAI configuration
rag = LightRAG(
enable_llm_cache=args.enable_cache,
working_dir=args.working_dir,
llm_model_func=async_openai_complete,
llm_model_name=args.model,
llm_model_max_token_size=args.max_tokens,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: azure_openai_embedding(
texts, model=args.embedding_model
),
),
)
@app.on_event("startup")
async def startup_event():
"""Index all files in input directory during startup"""
try:
new_files = doc_manager.scan_directory()
for file_path in new_files:
try:
# Use async file reading
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read()
# Use the async version of insert directly
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
logging.info(f"Indexed file: {file_path}")
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
new_files = doc_manager.scan_directory()
indexed_count = 0
for file_path in new_files:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
indexed_count += 1
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
return {
"status": "success",
"indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/resetcache", dependencies=[Depends(optional_api_key)])
async def reset_cache():
"""Manually reset cache"""
try:
cachefile = args.working_dir + "/kv_store_llm_response_cache.json"
if os.path.exists(cachefile):
with open(cachefile, "w") as f:
f.write("{}")
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
return {
"status": "success",
"message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=False,
only_need_context=request.only_need_context,
),
)
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
),
)
if inspect.isasyncgen(response):
async def stream_generator():
async for chunk in response:
yield json.dumps({"data": chunk}) + "\n"
return StreamingResponse(
stream_generator(), media_type="application/json"
)
else:
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=1,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
rag.insert(text)
else:
raise HTTPException(
status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported",
)
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
rag.insert(text)
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
status_message = f"Successfully inserted {inserted_count} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status="success" if inserted_count > 0 else "partial_success",
message=status_message,
document_count=len(files),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"indexed_files": len(doc_manager.indexed_files),
"configuration": {
"model": args.model,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
"embedding_dim": embedding_dim,
},
}
return app
def main():
args = parse_args()
import uvicorn
app = create_app(args)
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()

lightrag/api/lightrag_server.py (new file)

@@ -0,0 +1,842 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import logging
import argparse
from lightrag import LightRAG, QueryParam
from lightrag.llm import lollms_model_complete, lollms_embed
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
from lightrag.utils import EmbeddingFunc
from typing import Optional, List, Union
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import os
import json
from fastapi import Depends, Security
from fastapi.responses import StreamingResponse  # used by /query/stream
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm
def get_default_host(binding_type: str) -> str:
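"""Return the default backend URL for a binding type; used when
--llm-binding-host / --embedding-binding-host are not supplied."""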
default_hosts = {
"ollama": "http://localhost:11434",
"lollms": "http://localhost:9600",
"azure_openai": "https://api.openai.com/v1",
"openai": "https://api.openai.com/v1",
}
return default_hosts.get(
binding_type, "http://localhost:11434"
) # fallback to ollama if unknown
def parse_args():
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Start by the bindings
parser.add_argument(
"--llm-binding",
default="ollama",
help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)",
)
parser.add_argument(
"--embedding-binding",
default="ollama",
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
)
# Parse just these arguments first
temp_args, _ = parser.parse_known_args()
# Add remaining arguments with dynamic defaults for hosts
# Server configuration
parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration
parser.add_argument(
"--working-dir",
default="./rag_storage",
help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# LLM Model configuration
default_llm_host = get_default_host(temp_args.llm_binding)
parser.add_argument(
"--llm-binding-host",
default=default_llm_host,
help=f"llm server host URL (default: {default_llm_host})",
)
parser.add_argument(
"--llm-model",
default="mistral-nemo:latest",
help="LLM model name (default: mistral-nemo:latest)",
)
# Embedding model configuration
default_embedding_host = get_default_host(temp_args.embedding_binding)
parser.add_argument(
"--embedding-binding-host",
default=default_embedding_host,
help=f"embedding server host URL (default: {default_embedding_host})",
)
parser.add_argument(
"--embedding-model",
default="bge-m3:latest",
help="Embedding model name (default: bge-m3:latest)",
)
def timeout_type(value):
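"""argparse type helper: accept an integer, or the literal "None" meaning no timeout."""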
if value is None or value == "None":
return None
return int(value)
parser.add_argument(
"--timeout",
default=None,
type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
)
# RAG configuration
parser.add_argument(
"--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
)
parser.add_argument(
"--max-tokens",
type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=1024,
help="Embedding dimensions (default: 1024)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
# Optional https parameters
parser.add_argument(
"--ssl", action="store_true", help="Enable HTTPS (default: False)"
)
parser.add_argument(
"--ssl-certfile",
default=None,
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=None,
help="Path to SSL private key file (required if --ssl is enabled)",
)
return parser.parse_args()
class DocumentManager:
"""Handles document operations and tracking"""
def __init__(
self,
input_dir: str,
supported_extensions: tuple = (".txt", ".md", ".pdf", ".docx", ".pptx"),
):
self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions
self.indexed_files = set()
# Create input directory if it doesn't exist
self.input_dir.mkdir(parents=True, exist_ok=True)
def scan_directory(self) -> List[Path]:
"""Scan input directory for new files"""
new_files = []
for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
def mark_as_indexed(self, file_path: Path):
"""Mark a file as indexed"""
self.indexed_files.add(file_path)
def is_supported_file(self, filename: str) -> bool:
"""Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
status: str
message: str
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
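# When a key is configured (via --key or the LIGHTRAG_API_KEY environment variable),
# clients must send it in the "X-API-Key" header, e.g.:
#   curl -H "X-API-Key: my-key" http://localhost:9621/health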
def create_app(args):
# Verify that bindings are correctly set up (azure_openai is also handled below)
if args.llm_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
raise Exception("llm binding not supported")
if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
raise Exception("embedding binding not supported")
# Add SSL validation
if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile:
raise Exception(
"SSL certificate and key files must be provided when SSL is enabled"
)
if not os.path.exists(args.ssl_certfile):
raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.2",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
# Initialize RAG
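# llm_model_func and embedding_func below are selected from the configured
# --llm-binding / --embedding-binding via chained conditional expressions.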
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=lollms_model_complete
if args.llm_binding == "lollms"
else ollama_model_complete
if args.llm_binding == "ollama"
else azure_openai_complete_if_cache
if args.llm_binding == "azure_openai"
else openai_complete_if_cache,
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={
"host": args.llm_binding_host,
"timeout": args.timeout,
"options": {"num_ctx": args.max_tokens},
},
embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.embedding_binding == "lollms"
else ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.embedding_binding == "ollama"
else azure_openai_embedding(
texts,
model=args.embedding_model,  # no host is used for azure_openai
)
if args.embedding_binding == "azure_openai"
else openai_embedding(
texts,
model=args.embedding_model,  # no host is used for openai
),
),
)
async def index_file(file_path: Union[str, Path]) -> None:
"""Index all files inside the folder with support for multiple file formats
Args:
file_path: Path to the file to be indexed (str or Path object)
Raises:
ValueError: If file format is not supported
FileNotFoundError: If file doesn't exist
"""
if not pm.is_installed("aiofiles"):
pm.install("aiofiles")
# Convert to Path object if string
file_path = Path(file_path)
# Check if file exists
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
content = ""
# Get file extension in lowercase
ext = file_path.suffix.lower()
match ext:
case ".txt" | ".md":
# Text files handling
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read()
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from pypdf2 import PdfReader
# PDF handling
reader = PdfReader(str(file_path))
content = ""
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
# Word document handling
doc = Document(file_path)
content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation
# PowerPoint handling
prs = Presentation(file_path)
content = ""
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case _:
raise ValueError(f"Unsupported file format: {ext}")
# Insert content into RAG system
if content:
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
logging.info(f"Successfully indexed file: {file_path}")
else:
logging.warning(f"No content extracted from file: {file_path}")
@app.on_event("startup")
async def startup_event():
"""Index all files in input directory during startup"""
try:
new_files = doc_manager.scan_directory()
for file_path in new_files:
try:
await index_file(file_path)
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
new_files = doc_manager.scan_directory()
indexed_count = 0
for file_path in new_files:
try:
await index_file(file_path)
indexed_count += 1
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
return {
"status": "success",
"indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file
await index_file(file_path)
return {
"status": "success",
"message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
),
)
if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
# aquery must be awaited; with stream=True it resolves to an async generator
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
),
)
async def stream_generator():
async for chunk in response:
yield json.dumps({"data": chunk}) + "\n"
# Wrap the generator so FastAPI actually streams it to the client
return StreamingResponse(stream_generator(), media_type="application/json")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=1,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
"""Insert a file directly into the RAG system
Args:
file: Uploaded file
description: Optional description of the file
Returns:
InsertResponse: Status of the insertion operation
Raises:
HTTPException: For unsupported file types or processing errors
"""
try:
content = ""
# Get file extension in lowercase
ext = Path(file.filename).suffix.lower()
match ext:
case ".txt" | ".md":
# Text files handling
text_content = await file.read()
content = text_content.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from pypdf2 import PdfReader
from io import BytesIO
# Read PDF from memory
pdf_content = await file.read()
pdf_file = BytesIO(pdf_content)
reader = PdfReader(pdf_file)
content = ""
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
# Read DOCX from memory
docx_content = await file.read()
docx_file = BytesIO(docx_content)
doc = Document(docx_file)
content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation
from io import BytesIO
# Read PPTX from memory
pptx_content = await file.read()
pptx_file = BytesIO(pptx_content)
prs = Presentation(pptx_file)
content = ""
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case _:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
# Insert content into RAG system
if content:
# Add description if provided
if description:
content = f"{description}\n\n{content}"
await rag.ainsert(content)
logging.info(f"Successfully indexed file: {file.filename}")
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
else:
raise HTTPException(
status_code=400,
detail="No content could be extracted from the file",
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
logging.error(f"Error processing file {file.filename}: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
"""Process multiple files in batch mode
Args:
files: List of files to process
Returns:
InsertResponse: Status of the batch insertion operation
Raises:
HTTPException: For processing errors
"""
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = ""
ext = Path(file.filename).suffix.lower()
match ext:
case ".txt" | ".md":
text_content = await file.read()
content = text_content.decode("utf-8")
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from pypdf2 import PdfReader
from io import BytesIO
pdf_content = await file.read()
pdf_file = BytesIO(pdf_content)
reader = PdfReader(pdf_file)
for page in reader.pages:
content += page.extract_text() + "\n"
case ".docx":
if not pm.is_installed("docx"):
pm.install("docx")
from docx import Document
from io import BytesIO
docx_content = await file.read()
docx_file = BytesIO(docx_content)
doc = Document(docx_file)
content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx":
if not pm.is_installed("pptx"):
pm.install("pptx")
from pptx import Presentation
from io import BytesIO
pptx_content = await file.read()
pptx_file = BytesIO(pptx_content)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
case _:
failed_files.append(f"{file.filename} (unsupported type)")
continue
if content:
await rag.ainsert(content)
inserted_count += 1
logging.info(f"Successfully indexed file: {file.filename}")
else:
failed_files.append(f"{file.filename} (no content extracted)")
except UnicodeDecodeError:
failed_files.append(f"{file.filename} (encoding error)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
logging.error(f"Error processing file {file.filename}: {str(e)}")
# Prepare status message
if inserted_count == len(files):
status = "success"
status_message = f"Successfully inserted all {inserted_count} documents"
elif inserted_count > 0:
status = "partial_success"
status_message = f"Successfully inserted {inserted_count} out of {len(files)} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
else:
status = "failure"
status_message = "No documents were successfully inserted"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status=status,
message=status_message,
document_count=inserted_count,
)
except Exception as e:
logging.error(f"Batch processing error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"indexed_files": len(doc_manager.indexed_files),
"configuration": {
# LLM configuration binding/host address (if applicable)/model (if applicable)
"llm_binding": args.llm_binding,
"llm_binding_host": args.llm_binding_host,
"llm_model": args.llm_model,
# embedding model configuration binding/host address (if applicable)/model (if applicable)
"embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
},
}
return app
def main():
args = parse_args()
import uvicorn
app = create_app(args)
uvicorn_config = {
"app": app,
"host": args.host,
"port": args.port,
}
if args.ssl:
uvicorn_config.update(
{
"ssl_certfile": args.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile,
}
)
uvicorn.run(**uvicorn_config)
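# Example launch with HTTPS enabled (certificate/key paths are placeholders):
#   lightrag-server --ssl --ssl-certfile ./cert.pem --ssl-keyfile ./key.pem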
if __name__ == "__main__":
main()

lightrag/api/lollms_lightrag_server.py (deleted)

@@ -1,492 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import logging
import argparse
from lightrag import LightRAG, QueryParam
from lightrag.llm import lollms_model_complete, lollms_embed
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
def parse_args():
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Server configuration
parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration
parser.add_argument(
"--working-dir",
default="./rag_storage",
help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# Model configuration
parser.add_argument(
"--model",
default="mistral-nemo:latest",
help="LLM model name (default: mistral-nemo:latest)",
)
parser.add_argument(
"--embedding-model",
default="bge-m3:latest",
help="Embedding model name (default: bge-m3:latest)",
)
parser.add_argument(
"--lollms-host",
default="http://localhost:9600",
help="lollms host URL (default: http://localhost:9600)",
)
# RAG configuration
parser.add_argument(
"--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
)
parser.add_argument(
"--max-tokens",
type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=1024,
help="Embedding dimensions (default: 1024)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
class DocumentManager:
"""Handles document operations and tracking"""
def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions
self.indexed_files = set()
# Create input directory if it doesn't exist
self.input_dir.mkdir(parents=True, exist_ok=True)
def scan_directory(self) -> List[Path]:
"""Scan input directory for new files"""
new_files = []
for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
def mark_as_indexed(self, file_path: Path):
"""Mark a file as indexed"""
self.indexed_files.add(file_path)
def is_supported_file(self, filename: str) -> bool:
"""Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
status: str
message: str
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
def create_app(args):
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
# Initialize RAG
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=lollms_model_complete,
llm_model_name=args.model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={
"host": args.lollms_host,
"options": {"num_ctx": args.max_tokens},
},
embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed(
texts, embed_model=args.embedding_model, host=args.lollms_host
),
),
)
@app.on_event("startup")
async def startup_event():
"""Index all files in input directory during startup"""
try:
new_files = doc_manager.scan_directory()
for file_path in new_files:
try:
# Use async file reading
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read()
# Use the async version of insert directly
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
logging.info(f"Indexed file: {file_path}")
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
new_files = doc_manager.scan_directory()
indexed_count = 0
for file_path in new_files:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
indexed_count += 1
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
return {
"status": "success",
"indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
return {
"status": "success",
"message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
),
)
if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
),
)
async def stream_generator():
async for chunk in response:
yield chunk
return stream_generator()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=1,  # one text payload per request; LightRAG instances do not define len()
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
await rag.ainsert(text)
else:
raise HTTPException(
status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported",
)
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
await rag.ainsert(text)
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
status_message = f"Successfully inserted {inserted_count} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status="success" if inserted_count > 0 else "partial_success",
message=status_message,
document_count=len(files),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"indexed_files": len(doc_manager.indexed_files),
"configuration": {
"model": args.model,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
"lollms_host": args.lollms_host,
},
}
return app
def main():
args = parse_args()
import uvicorn
app = create_app(args)
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()
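For orientation, here is a minimal client sketch against the server above. It is an illustration, not part of the changeset: it assumes the server is running on the default port 9621 used elsewhere in this changeset and that no `--key`/`LIGHTRAG_API_KEY` is configured; with a key set, an `X-API-Key` header would be required on every call.

```python
# Minimal client sketch (assumed setup: server on localhost:9621, no API key).
import requests

BASE = "http://localhost:9621"

# Insert a document as raw text via the /documents/text route
r = requests.post(f"{BASE}/documents/text", json={"text": "Scrooge was a miser."})
print(r.json())

# Query it back; mode is one of: naive, local, global, hybrid
r = requests.post(f"{BASE}/query", json={"query": "Who was Scrooge?", "mode": "hybrid"})
print(r.json()["response"])
```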


@@ -1,491 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import logging
import argparse
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
def parse_args():
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Server configuration
parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration
parser.add_argument(
"--working-dir",
default="./rag_storage",
help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# Model configuration
parser.add_argument(
"--model",
default="mistral-nemo:latest",
help="LLM model name (default: mistral-nemo:latest)",
)
parser.add_argument(
"--embedding-model",
default="bge-m3:latest",
help="Embedding model name (default: bge-m3:latest)",
)
parser.add_argument(
"--ollama-host",
default="http://localhost:11434",
help="Ollama host URL (default: http://localhost:11434)",
)
# RAG configuration
parser.add_argument(
"--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
)
parser.add_argument(
"--max-tokens",
type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=1024,
help="Embedding dimensions (default: 1024)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
class DocumentManager:
"""Handles document operations and tracking"""
def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions
self.indexed_files = set()
# Create input directory if it doesn't exist
self.input_dir.mkdir(parents=True, exist_ok=True)
def scan_directory(self) -> List[Path]:
"""Scan input directory for new files"""
new_files = []
for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
def mark_as_indexed(self, file_path: Path):
"""Mark a file as indexed"""
self.indexed_files.add(file_path)
def is_supported_file(self, filename: str) -> bool:
"""Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
status: str
message: str
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
def create_app(args):
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
# Initialize RAG
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=ollama_model_complete,
llm_model_name=args.model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={
"host": args.ollama_host,
"options": {"num_ctx": args.max_tokens},
},
embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: ollama_embed(
texts, embed_model=args.embedding_model, host=args.ollama_host
),
),
)
@app.on_event("startup")
async def startup_event():
"""Index all files in input directory during startup"""
try:
new_files = doc_manager.scan_directory()
for file_path in new_files:
try:
# Use async file reading
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read()
# Use the async version of insert directly
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
logging.info(f"Indexed file: {file_path}")
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
new_files = doc_manager.scan_directory()
indexed_count = 0
for file_path in new_files:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
indexed_count += 1
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
return {
"status": "success",
"indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
return {
"status": "success",
"message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
),
)
if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = rag.query(
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
),
)
async def stream_generator():
async for chunk in response:
yield chunk
return stream_generator()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=len(rag),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
await rag.ainsert(text)
else:
raise HTTPException(
status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported",
)
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
await rag.ainsert(text)
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
status_message = f"Successfully inserted {inserted_count} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status="success" if inserted_count > 0 else "partial_success",
message=status_message,
document_count=len(files),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"indexed_files": len(doc_manager.indexed_files),
"configuration": {
"model": args.model,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
"ollama_host": args.ollama_host,
},
}
return app
def main():
args = parse_args()
import uvicorn
app = create_app(args)
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()


@@ -1,506 +0,0 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import asyncio
import logging
import argparse
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import nest_asyncio
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
def parse_args():
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with OpenAI integration"
)
# Server configuration
parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration
parser.add_argument(
"--working-dir",
default="./rag_storage",
help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# Model configuration
parser.add_argument(
"--model", default="gpt-4", help="OpenAI model name (default: gpt-4)"
)
parser.add_argument(
"--embedding-model",
default="text-embedding-3-large",
help="OpenAI embedding model (default: text-embedding-3-large)",
)
# RAG configuration
parser.add_argument(
"--max-tokens",
type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
return parser.parse_args()
class DocumentManager:
"""Handles document operations and tracking"""
def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions
self.indexed_files = set()
# Create input directory if it doesn't exist
self.input_dir.mkdir(parents=True, exist_ok=True)
def scan_directory(self) -> List[Path]:
"""Scan input directory for new files"""
new_files = []
for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
def mark_as_indexed(self, file_path: Path):
"""Mark a file as indexed"""
self.indexed_files.add(file_path)
def is_supported_file(self, filename: str) -> bool:
"""Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
status: str
message: str
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."]
embedding = await openai_embedding(test_text, model=embedding_model)
return embedding.shape[1]
def create_app(args):
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
# Get embedding dimensions
embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model))
async def async_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
):
"""Async wrapper for OpenAI completion"""
return await openai_complete_if_cache(
args.model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
# Initialize RAG with OpenAI configuration
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=async_openai_complete,
llm_model_name=args.model,
llm_model_max_token_size=args.max_tokens,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: openai_embedding(texts, model=args.embedding_model),
),
)
@app.on_event("startup")
async def startup_event():
"""Index all files in input directory during startup"""
try:
new_files = doc_manager.scan_directory()
for file_path in new_files:
try:
# Use async file reading
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read()
# Use the async version of insert directly
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
logging.info(f"Indexed file: {file_path}")
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
new_files = doc_manager.scan_directory()
indexed_count = 0
for file_path in new_files:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
rag.insert(content)
doc_manager.mark_as_indexed(file_path)
indexed_count += 1
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
return {
"status": "success",
"indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
rag.insert(content)
doc_manager.mark_as_indexed(file_path)
return {
"status": "success",
"message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
),
)
if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = rag.query(
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
),
)
async def stream_generator():
async for chunk in response:
yield chunk
return stream_generator()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
rag.insert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=len(rag),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
rag.insert(text)
else:
raise HTTPException(
status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported",
)
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
rag.insert(text)
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
status_message = f"Successfully inserted {inserted_count} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status="success" if inserted_count > 0 else "partial_success",
message=status_message,
document_count=len(files),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"indexed_files": len(doc_manager.indexed_files),
"configuration": {
"model": args.model,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
"embedding_dim": embedding_dim,
},
}
return app
def main():
args = parse_args()
import uvicorn
app = create_app(args)
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()


@@ -7,6 +7,7 @@ nest_asyncio
 numpy
 ollama
 openai
+pipmaster
 python-dotenv
 python-multipart
 tenacity


@@ -2,7 +2,7 @@ import os
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 from pymongo import MongoClient
+from typing import Union
 from lightrag.utils import logger
 from lightrag.base import BaseKVStorage

@@ -41,11 +41,35 @@ class MongoKVStorage(BaseKVStorage):
         return set([s for s in data if s not in existing_ids])

     async def upsert(self, data: dict[str, dict]):
+        if self.namespace == "llm_response_cache":
+            for mode, items in data.items():
+                for k, v in tqdm_async(items.items(), desc="Upserting"):
+                    key = f"{mode}_{k}"
+                    result = self._data.update_one(
+                        {"_id": key}, {"$setOnInsert": v}, upsert=True
+                    )
+                    if result.upserted_id:
+                        logger.debug(f"\nInserted new document with key: {key}")
+                    data[mode][k]["_id"] = key
+        else:
             for k, v in tqdm_async(data.items(), desc="Upserting"):
                 self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
                 data[k]["_id"] = k
         return data

+    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
+        if "llm_response_cache" == self.namespace:
+            res = {}
+            v = self._data.find_one({"_id": mode + "_" + id})
+            if v:
+                res[id] = v
+                logger.debug(f"llm_response_cache find one by:{id}")
+                return res
+            else:
+                return None
+        else:
+            return None
+
     async def drop(self):
         """ """
         pass
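The effect of this change is easiest to see in isolation: LLM-cache entries are now stored one per `(mode, key)` pair under a composite `_id`, and `$setOnInsert` makes the write insert-only, so an existing cached response is never overwritten. A runnable sketch follows, using `mongomock` purely as a stand-in for a MongoDB server (mongomock is an assumption of this sketch, not a LightRAG dependency):

```python
# Sketch of the composite cache-key scheme above, on a mongomock stand-in.
import mongomock

coll = mongomock.MongoClient()["rag"]["llm_response_cache"]

mode, k = "hybrid", "abc123"
key = f"{mode}_{k}"  # same composite _id the new upsert() branch builds
coll.update_one({"_id": key}, {"$setOnInsert": {"return": "cached answer"}}, upsert=True)
# A second upsert with different content is a no-op thanks to $setOnInsert:
coll.update_one({"_id": key}, {"$setOnInsert": {"return": "other"}}, upsert=True)

# get_by_mode_and_id() rebuilds the same key on lookup
print(coll.find_one({"_id": mode + "_" + k}))  # {'_id': 'hybrid_abc123', 'return': 'cached answer'}
```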


@@ -39,6 +39,7 @@ class Neo4JStorage(BaseGraphStorage):
         URI = os.environ["NEO4J_URI"]
         USERNAME = os.environ["NEO4J_USERNAME"]
         PASSWORD = os.environ["NEO4J_PASSWORD"]
+        MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
         DATABASE = os.environ.get(
             "NEO4J_DATABASE"
         )  # If this param is None, the home database will be used. If it is not None, the specified database will be used.
@@ -47,7 +48,11 @@ class Neo4JStorage(BaseGraphStorage):
             URI, auth=(USERNAME, PASSWORD)
         )
         _database_name = "home database" if DATABASE is None else f"database {DATABASE}"
-        with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
+        with GraphDatabase.driver(
+            URI,
+            auth=(USERNAME, PASSWORD),
+            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
+        ) as _sync_driver:
             try:
                 with _sync_driver.session(database=DATABASE) as session:
                     try:
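One caveat worth noting: `os.environ.get()` returns a string whenever the variable is set, so the pool size is only an `int` when the `800` default applies; the Neo4j driver may tolerate a string, but an explicit cast is the safer pattern. A small sketch of setting the new knob before constructing the storage:

```python
# Hedged sketch: exporting the new pool-size knob. The int() cast is a
# precaution that this changeset itself does not apply.
import os

os.environ["NEO4J_MAX_CONNECTION_POOL_SIZE"] = "200"
pool_size = int(os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800))
print(pool_size)  # 200
```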


@@ -130,6 +130,7 @@ class PostgreSQLDB:
         data: Union[list, dict] = None,
         for_age: bool = False,
         graph_name: str = None,
+        upsert: bool = False,
     ):
         try:
             async with self.pool.acquire() as connection:
@@ -140,8 +141,16 @@ class PostgreSQLDB:
                     await connection.execute(sql)
                 else:
                     await connection.execute(sql, *data.values())
+        except (
+            asyncpg.exceptions.UniqueViolationError,
+            asyncpg.exceptions.DuplicateTableError,
+        ) as e:
+            if upsert:
+                print("Key value duplicate, but upsert succeeded.")
+            else:
+                logger.error(f"Upsert error: {e}")
         except Exception as e:
-            logger.error(f"PostgreSQL database error: {e}")
+            logger.error(f"PostgreSQL database error: {e.__class__} - {e}")
             print(sql)
             print(data)
             raise
@@ -568,10 +577,10 @@ class PGGraphStorage(BaseGraphStorage):
             if dtype == "vertex":
                 vertex = json.loads(v)
-                field = json.loads(v).get("properties")
+                field = vertex.get("properties")
                 if not field:
                     field = {}
-                field["label"] = PGGraphStorage._decode_graph_label(vertex["label"])
+                field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
                 d[k] = field
             # convert edge from id-label->id by replacing id with node information
             # we only do this if the vertex was also returned in the query
@@ -666,73 +675,8 @@ class PGGraphStorage(BaseGraphStorage):
         # otherwise return the value stripping out some common special chars
         return field.replace("(", "_").replace(")", "")

-    @staticmethod
-    def _wrap_query(query: str, graph_name: str, **params: str) -> str:
-        """
-        Convert a cypher query to an Apache Age compatible
-        sql query by wrapping the cypher query in ag_catalog.cypher,
-        casting results to agtype and building a select statement
-
-        Args:
-            query (str): a valid cypher query
-            graph_name (str): the name of the graph to query
-            params (dict): parameters for the query
-
-        Returns:
-            str: an equivalent pgsql query
-        """
-        # pgsql template
-        template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
-            {query}
-        $$) AS ({fields})"""
-
-        # if there are any returned fields they must be added to the pgsql query
-        if "return" in query.lower():
-            # parse return statement to identify returned fields
-            fields = (
-                query.lower()
-                .split("return")[-1]
-                .split("distinct")[-1]
-                .split("order by")[0]
-                .split("skip")[0]
-                .split("limit")[0]
-                .split(",")
-            )
-
-            # raise exception if RETURN * is found as we can't resolve the fields
-            if "*" in [x.strip() for x in fields]:
-                raise ValueError(
-                    "AGE graph does not support 'RETURN *'"
-                    + " statements in Cypher queries"
-                )
-
-            # get pgsql formatted field names
-            fields = [
-                PGGraphStorage._get_col_name(field, idx)
-                for idx, field in enumerate(fields)
-            ]
-
-            # build resulting pgsql relation
-            fields_str = ", ".join(
-                [field.split(".")[-1] + " agtype" for field in fields]
-            )
-
-        # if no return statement we still need to return a single field of type agtype
-        else:
-            fields_str = "a agtype"
-
-        select_str = "*"
-
-        return template.format(
-            graph_name=graph_name,
-            query=query.format(**params),
-            fields=fields_str,
-            projection=select_str,
-        )
-
     async def _query(
-        self, query: str, readonly=True, upsert_edge=False, **params: str
+        self, query: str, readonly: bool = True, upsert: bool = False
     ) -> List[Dict[str, Any]]:
         """
         Query the graph by taking a cypher query, converting it to an
@@ -746,7 +690,7 @@ class PGGraphStorage(BaseGraphStorage):
             List[Dict[str, Any]]: a list of dictionaries containing the result set
         """
         # convert cypher query to pgsql/age query
-        wrapped_query = self._wrap_query(query, self.graph_name, **params)
+        wrapped_query = query

         # execute the query, rolling back on an error
         try:
@@ -758,22 +702,16 @@ class PGGraphStorage(BaseGraphStorage):
                     graph_name=self.graph_name,
                 )
             else:
-                # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
-                # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
-                if upsert_edge:
-                    data = await self.db.execute(
-                        f"{wrapped_query};{wrapped_query};",
-                        for_age=True,
-                        graph_name=self.graph_name,
-                    )
-                else:
-                    data = await self.db.execute(
-                        wrapped_query, for_age=True, graph_name=self.graph_name
-                    )
+                data = await self.db.execute(
+                    wrapped_query,
+                    for_age=True,
+                    graph_name=self.graph_name,
+                    upsert=upsert,
+                )
         except Exception as e:
             raise PGGraphQueryException(
                 {
-                    "message": f"Error executing graph query: {query.format(**params)}",
+                    "message": f"Error executing graph query: {query}",
                     "wrapped": wrapped_query,
                     "detail": str(e),
                 }
@@ -788,77 +726,85 @@ class PGGraphStorage(BaseGraphStorage):
         return result

     async def has_node(self, node_id: str) -> bool:
-        entity_name_label = node_id.strip('"')
-        query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
-        params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
-        single_result = (await self._query(query, **params))[0]
+        entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+
+        query = """SELECT * FROM cypher('%s', $$
+              MATCH (n:Entity {node_id: "%s"})
+              RETURN count(n) > 0 AS node_exists
+            $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
+
+        single_result = (await self._query(query))[0]
         logger.debug(
             "{%s}:query:{%s}:result:{%s}",
             inspect.currentframe().f_code.co_name,
-            query.format(**params),
+            query,
             single_result["node_exists"],
         )
         return single_result["node_exists"]

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        entity_name_label_source = source_node_id.strip('"')
-        entity_name_label_target = target_node_id.strip('"')
-        query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
-                RETURN COUNT(r) > 0 AS edge_exists"""
-        params = {
-            "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
-            "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
-        }
-        single_result = (await self._query(query, **params))[0]
+        src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
+        tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
+
+        query = """SELECT * FROM cypher('%s', $$
+              MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
+              RETURN COUNT(r) > 0 AS edge_exists
+            $$) AS (edge_exists bool)""" % (
+            self.graph_name,
+            src_label,
+            tgt_label,
+        )
+
+        single_result = (await self._query(query))[0]
         logger.debug(
             "{%s}:query:{%s}:result:{%s}",
             inspect.currentframe().f_code.co_name,
-            query.format(**params),
+            query,
             single_result["edge_exists"],
         )
         return single_result["edge_exists"]

     async def get_node(self, node_id: str) -> Union[dict, None]:
-        entity_name_label = node_id.strip('"')
-        query = """MATCH (n:`{label}`) RETURN n"""
-        params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
-        record = await self._query(query, **params)
+        label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+        query = """SELECT * FROM cypher('%s', $$
+              MATCH (n:Entity {node_id: "%s"})
+              RETURN n
+            $$) AS (n agtype)""" % (self.graph_name, label)
+        record = await self._query(query)
         if record:
             node = record[0]
             node_dict = node["n"]
             logger.debug(
                 "{%s}: query: {%s}, result: {%s}",
                 inspect.currentframe().f_code.co_name,
-                query.format(**params),
+                query,
                 node_dict,
             )
             return node_dict
         return None

     async def node_degree(self, node_id: str) -> int:
-        entity_name_label = node_id.strip('"')
-        query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
-        params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
-        record = (await self._query(query, **params))[0]
+        label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+
+        query = """SELECT * FROM cypher('%s', $$
+              MATCH (n:Entity {node_id: "%s"})-[]->(x)
+              RETURN count(x) AS total_edge_count
+            $$) AS (total_edge_count integer)""" % (self.graph_name, label)
+        record = (await self._query(query))[0]
         if record:
             edge_count = int(record["total_edge_count"])
             logger.debug(
                 "{%s}:query:{%s}:result:{%s}",
                 inspect.currentframe().f_code.co_name,
-                query.format(**params),
+                query,
                 edge_count,
             )
             return edge_count

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        entity_name_label_source = src_id.strip('"')
-        entity_name_label_target = tgt_id.strip('"')
-        src_degree = await self.node_degree(entity_name_label_source)
-        trg_degree = await self.node_degree(entity_name_label_target)
+        src_degree = await self.node_degree(src_id)
+        trg_degree = await self.node_degree(tgt_id)

         # Convert None to 0 for addition
         src_degree = 0 if src_degree is None else src_degree
@@ -885,23 +831,25 @@ class PGGraphStorage(BaseGraphStorage):
         Returns:
             list: List of all relationships/edges found
         """
-        entity_name_label_source = source_node_id.strip('"')
-        entity_name_label_target = target_node_id.strip('"')
+        src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
+        tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))

-        query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
-            RETURN properties(r) as edge_properties
-            LIMIT 1"""
-        params = {
-            "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
-            "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
-        }
-        record = await self._query(query, **params)
+        query = """SELECT * FROM cypher('%s', $$
+              MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
+              RETURN properties(r) as edge_properties
+              LIMIT 1
+            $$) AS (edge_properties agtype)""" % (
+            self.graph_name,
+            src_label,
+            tgt_label,
+        )
+        record = await self._query(query)
         if record and record[0] and record[0]["edge_properties"]:
             result = record[0]["edge_properties"]
             logger.debug(
                 "{%s}:query:{%s}:result:{%s}",
                 inspect.currentframe().f_code.co_name,
-                query.format(**params),
+                query,
                 result,
             )
             return result
@@ -911,29 +859,41 @@ class PGGraphStorage(BaseGraphStorage):
         Retrieves all edges (relationships) for a particular node identified by its label.
         :return: List of dictionaries containing edge information
         """
-        node_label = source_node_id.strip('"')
+        label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))

-        query = """MATCH (n:`{label}`)
-            OPTIONAL MATCH (n)-[r]-(connected)
-            RETURN n, r, connected"""
-        params = {"label": PGGraphStorage._encode_graph_label(node_label)}
-        results = await self._query(query, **params)
+        query = """SELECT * FROM cypher('%s', $$
+              MATCH (n:Entity {node_id: "%s"})
+              OPTIONAL MATCH (n)-[r]-(connected)
+              RETURN n, r, connected
+            $$) AS (n agtype, r agtype, connected agtype)""" % (
+            self.graph_name,
+            label,
+        )
+        results = await self._query(query)

         edges = []
         for record in results:
             source_node = record["n"] if record["n"] else None
             connected_node = record["connected"] if record["connected"] else None

             source_label = (
-                source_node["label"] if source_node and source_node["label"] else None
+                source_node["node_id"]
+                if source_node and source_node["node_id"]
+                else None
             )
             target_label = (
-                connected_node["label"]
-                if connected_node and connected_node["label"]
+                connected_node["node_id"]
+                if connected_node and connected_node["node_id"]
                 else None
             )

             if source_label and target_label:
-                edges.append((source_label, target_label))
+                edges.append(
+                    (
+                        PGGraphStorage._decode_graph_label(source_label),
+                        PGGraphStorage._decode_graph_label(target_label),
+                    )
+                )

         return edges
@@ -950,17 +910,21 @@ class PGGraphStorage(BaseGraphStorage):
             node_id: The unique identifier for the node (used as label)
             node_data: Dictionary of node properties
         """
-        label = node_id.strip('"')
+        label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
         properties = node_data

-        query = """MERGE (n:`{label}`)
-        SET n += {properties}"""
-        params = {
-            "label": PGGraphStorage._encode_graph_label(label),
-            "properties": PGGraphStorage._format_properties(properties),
-        }
+        query = """SELECT * FROM cypher('%s', $$
+              MERGE (n:Entity {node_id: "%s"})
+              SET n += %s
+              RETURN n
+            $$) AS (n agtype)""" % (
+            self.graph_name,
+            label,
+            PGGraphStorage._format_properties(properties),
+        )

         try:
-            await self._query(query, readonly=False, **params)
+            await self._query(query, readonly=False, upsert=True)
             logger.debug(
                 "Upserted node with label '{%s}' and properties: {%s}",
                 label,
@@ -986,28 +950,30 @@ class PGGraphStorage(BaseGraphStorage):
             target_node_id (str): Label of the target node (used as identifier)
             edge_data (dict): Dictionary of properties to set on the edge
         """
-        source_node_label = source_node_id.strip('"')
-        target_node_label = target_node_id.strip('"')
+        src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
+        tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
         edge_properties = edge_data

-        query = """MATCH (source:`{src_label}`)
-        WITH source
-        MATCH (target:`{tgt_label}`)
-        MERGE (source)-[r:DIRECTED]->(target)
-        SET r += {properties}
-        RETURN r"""
-        params = {
-            "src_label": PGGraphStorage._encode_graph_label(source_node_label),
-            "tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
-            "properties": PGGraphStorage._format_properties(edge_properties),
-        }
+        query = """SELECT * FROM cypher('%s', $$
+              MATCH (source:Entity {node_id: "%s"})
+              WITH source
+              MATCH (target:Entity {node_id: "%s"})
+              MERGE (source)-[r:DIRECTED]->(target)
+              SET r += %s
+              RETURN r
+            $$) AS (r agtype)""" % (
+            self.graph_name,
+            src_label,
+            tgt_label,
+            PGGraphStorage._format_properties(edge_properties),
+        )
         # logger.info(f"-- inserting edge after formatted: {params}")
         try:
-            await self._query(query, readonly=False, upsert_edge=True, **params)
+            await self._query(query, readonly=False, upsert=True)
             logger.debug(
                 "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
-                source_node_label,
-                target_node_label,
+                src_label,
+                tgt_label,
                 edge_properties,
             )
         except Exception as e:
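All of these methods now share one query shape: the Cypher is embedded verbatim in Apache AGE's `cypher()` table function and the result columns are declared with `agtype` casts, replacing the old `_wrap_query()` string machinery. A standalone sketch of that shape (the graph name and node id here are hypothetical placeholders):

```python
# Shape of the raw Apache AGE query the refactored methods build by hand.
graph_name = "my_graph"  # placeholder
node_id = "SCROOGE"      # in the real code this has already passed _encode_graph_label

query = """SELECT * FROM cypher('%s', $$
      MATCH (n:Entity {node_id: "%s"})
      RETURN count(n) > 0 AS node_exists
    $$) AS (node_exists bool)""" % (graph_name, node_id)
print(query)
```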


@@ -61,7 +61,7 @@ db = PostgreSQLDB(
"port": 15432, "port": 15432,
"user": "rag", "user": "rag",
"password": "rag", "password": "rag",
"database": "rag", "database": "r1",
} }
) )
@@ -74,8 +74,12 @@ async def query_with_age():
         embedding_func=None,
     )
     graph.db = db
-    res = await graph.get_node('"CHRISTMAS-TIME"')
+    res = await graph.get_node('"A CHRISTMAS CAROL"')
     print("Node is: ", res)
+    res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
+    print("Edge is: ", res)
+    res = await graph.get_node_edges('"SCROOGE"')
+    print("Node Edges are: ", res)


 async def create_edge_with_age():


@@ -45,6 +45,7 @@ from .storage import (
 from .prompt import GRAPH_FIELD_SEP

 # future KG integrations
 # from .kg.ArangoDB_impl import (
@@ -187,6 +188,10 @@ class LightRAG:
     # Add new field for document status storage type
     doc_status_storage: str = field(default="JsonDocStatusStorage")

+    # Custom Chunking Function
+    chunking_func: callable = chunking_by_token_size
+    chunking_func_kwargs: dict = field(default_factory=dict)
+
     def __post_init__(self):
         log_file = os.path.join("lightrag.log")
         set_logger(log_file)
@@ -315,15 +320,25 @@ class LightRAG:
"JsonDocStatusStorage": JsonDocStatusStorage, "JsonDocStatusStorage": JsonDocStatusStorage,
} }
def insert(self, string_or_strings): def insert(
self, string_or_strings, split_by_character=None, split_by_character_only=False
):
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings)) return loop.run_until_complete(
self.ainsert(string_or_strings, split_by_character, split_by_character_only)
)
async def ainsert(self, string_or_strings): async def ainsert(
self, string_or_strings, split_by_character=None, split_by_character_only=False
):
"""Insert documents with checkpoint support """Insert documents with checkpoint support
Args: Args:
string_or_strings: Single document string or list of document strings string_or_strings: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
chunk_size, split the sub chunk by token size.
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
""" """
if isinstance(string_or_strings, str): if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings] string_or_strings = [string_or_strings]
@@ -379,11 +394,14 @@ class LightRAG:
                     **dp,
                     "full_doc_id": doc_id,
                 }
-                for dp in chunking_by_token_size(
+                for dp in self.chunking_func(
                     doc["content"],
+                    split_by_character=split_by_character,
+                    split_by_character_only=split_by_character_only,
                     overlap_token_size=self.chunk_overlap_token_size,
                     max_token_size=self.chunk_token_size,
                     tiktoken_model=self.tiktoken_model_name,
+                    **self.chunking_func_kwargs,
                 )
             }
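Taken together, the two hunks above wire the new parameters from `insert()` down into the chunker. A hedged usage sketch (`rag` and `long_text` are assumed to come from an initialization like the servers earlier in this changeset):

```python
# Hedged usage sketch of the new chunking controls; `rag` is an initialized
# LightRAG instance and `long_text` any large document string.
rag.insert(long_text, split_by_character="\n\n")
# split on blank lines, then re-split any oversize piece by token count

rag.insert(long_text, split_by_character="\n\n", split_by_character_only=True)
# split on blank lines only, even if a piece exceeds chunk_token_size
```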
@@ -455,6 +473,73 @@ class LightRAG:
             # Ensure all indexes are updated after each document
             await self._insert_done()

+    def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.ainsert_custom_chunks(full_text, text_chunks)
+        )
+
+    async def ainsert_custom_chunks(self, full_text: str, text_chunks: list[str]):
+        update_storage = False
+        try:
+            doc_key = compute_mdhash_id(full_text.strip(), prefix="doc-")
+            new_docs = {doc_key: {"content": full_text.strip()}}
+
+            _add_doc_keys = await self.full_docs.filter_keys([doc_key])
+            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+            if not len(new_docs):
+                logger.warning("This document is already in the storage.")
+                return
+
+            update_storage = True
+            logger.info(f"[New Docs] inserting {len(new_docs)} docs")
+
+            inserting_chunks = {}
+            for chunk_text in text_chunks:
+                chunk_text_stripped = chunk_text.strip()
+                chunk_key = compute_mdhash_id(chunk_text_stripped, prefix="chunk-")
+
+                inserting_chunks[chunk_key] = {
+                    "content": chunk_text_stripped,
+                    "full_doc_id": doc_key,
+                }
+
+            _add_chunk_keys = await self.text_chunks.filter_keys(
+                list(inserting_chunks.keys())
+            )
+            inserting_chunks = {
+                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+            }
+            if not len(inserting_chunks):
+                logger.warning("All chunks are already in the storage.")
+                return
+
+            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
+
+            await self.chunks_vdb.upsert(inserting_chunks)
+
+            logger.info("[Entity Extraction]...")
+            maybe_new_kg = await extract_entities(
+                inserting_chunks,
+                knowledge_graph_inst=self.chunk_entity_relation_graph,
+                entity_vdb=self.entities_vdb,
+                relationships_vdb=self.relationships_vdb,
+                global_config=asdict(self),
+            )
+
+            if maybe_new_kg is None:
+                logger.warning("No new entities and relationships found")
+                return
+            else:
+                self.chunk_entity_relation_graph = maybe_new_kg
+
+            await self.full_docs.upsert(new_docs)
+            await self.text_chunks.upsert(inserting_chunks)
+
+        finally:
+            if update_storage:
+                await self._insert_done()
+
     async def _insert_done(self):
         tasks = []
         for storage_inst in [
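The new `insert_custom_chunks()` path skips the chunker entirely and ingests caller-supplied chunks, deduplicating by content hash before entity extraction. A hedged usage sketch (again assuming an initialized `rag` as above):

```python
# Hypothetical usage of the new pre-chunked ingestion path.
full_text = "Marley was dead: to begin with. There is no doubt whatever about that."
chunks = [
    "Marley was dead: to begin with.",
    "There is no doubt whatever about that.",
]
rag.insert_custom_chunks(full_text, chunks)  # bypasses chunking_by_token_size
```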


@@ -406,8 +406,9 @@ async def lollms_model_if_cache(
     full_prompt += prompt

     request_data["prompt"] = full_prompt
+    timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))

-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout=timeout) as session:
         if stream:

             async def inner():


@@ -4,7 +4,6 @@ import re
 from tqdm.asyncio import tqdm as tqdm_async
 from typing import Union
 from collections import Counter, defaultdict
-import warnings
 from .utils import (
     logger,
     clean_str,
@@ -34,10 +33,48 @@ import time
def chunking_by_token_size( def chunking_by_token_size(
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o" content: str,
split_by_character=None,
split_by_character_only=False,
overlap_token_size=128,
max_token_size=1024,
tiktoken_model="gpt-4o",
**kwargs,
): ):
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
results = [] results = []
+    if split_by_character:
+        raw_chunks = content.split(split_by_character)
+        new_chunks = []
+        if split_by_character_only:
+            for chunk in raw_chunks:
+                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                new_chunks.append((len(_tokens), chunk))
+        else:
+            for chunk in raw_chunks:
+                _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
+                if len(_tokens) > max_token_size:
+                    for start in range(
+                        0, len(_tokens), max_token_size - overlap_token_size
+                    ):
+                        chunk_content = decode_tokens_by_tiktoken(
+                            _tokens[start : start + max_token_size],
+                            model_name=tiktoken_model,
+                        )
+                        new_chunks.append(
+                            (min(max_token_size, len(_tokens) - start), chunk_content)
+                        )
+                else:
+                    new_chunks.append((len(_tokens), chunk))
+        for index, (_len, chunk) in enumerate(new_chunks):
+            results.append(
+                {
+                    "tokens": _len,
+                    "content": chunk.strip(),
+                    "chunk_order_index": index,
+                }
+            )
+    else:
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
+        for index, start in enumerate(
+            range(0, len(tokens), max_token_size - overlap_token_size)
+        ):
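The extended signature keeps the old behavior when `split_by_character` is unset; when set, content is split on that character first, and oversized pieces are re-chunked by token count with overlap unless `split_by_character_only=True`. An illustrative call (the import path assumes the function lives in `lightrag.operate`, as in the upstream layout; the sample text is made up):

```python
from lightrag.operate import chunking_by_token_size

doc = "Short intro section.\n\nA much longer section that may exceed the token budget..."

chunks = chunking_by_token_size(
    doc,
    split_by_character="\n\n",      # split on blank lines first
    split_by_character_only=False,  # still re-chunk oversized pieces by tokens
    overlap_token_size=64,
    max_token_size=512,
    tiktoken_model="gpt-4o",
)
for c in chunks:
    print(c["chunk_order_index"], c["tokens"], c["content"][:40])
```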
@@ -582,15 +619,22 @@ async def kg_query(
         logger.warning("low_level_keywords and high_level_keywords is empty")
         return PROMPTS["fail_response"]
     if ll_keywords == [] and query_param.mode in ["local", "hybrid"]:
-        logger.warning("low_level_keywords is empty")
-        return PROMPTS["fail_response"]
-    else:
-        ll_keywords = ", ".join(ll_keywords)
+        logger.warning(
+            "low_level_keywords is empty, switching from %s mode to global mode",
+            query_param.mode,
+        )
+        query_param.mode = "global"
+
     if hl_keywords == [] and query_param.mode in ["global", "hybrid"]:
-        logger.warning("high_level_keywords is empty")
-        return PROMPTS["fail_response"]
-    else:
-        hl_keywords = ", ".join(hl_keywords)
+        logger.warning(
+            "high_level_keywords is empty, switching from %s mode to local mode",
+            query_param.mode,
+        )
+        query_param.mode = "local"
+
+    ll_keywords = ", ".join(ll_keywords) if ll_keywords else ""
+    hl_keywords = ", ".join(hl_keywords) if hl_keywords else ""
+
+    logger.info("Using %s mode for query processing", query_param.mode)
     # Build context
     keywords = [ll_keywords, hl_keywords]
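Where an empty keyword list previously aborted the query with `fail_response`, `kg_query` now degrades gracefully: missing low-level keywords demote a local/hybrid query to global mode, and missing high-level keywords demote a global/hybrid query to local mode. Callers need not change anything; a sketch assuming a configured `rag` instance:

```python
from lightrag import QueryParam

# Even if keyword extraction yields no low-level keywords for this question,
# the query proceeds in global mode instead of returning a fail response.
print(rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid"),
))
```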
@@ -656,78 +700,52 @@ async def _build_query_context(
     # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
     # hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
-    ll_kewwords, hl_keywrds = query[0], query[1]
-    if query_param.mode in ["local", "hybrid"]:
-        if ll_kewwords == "":
-            ll_entities_context, ll_relations_context, ll_text_units_context = (
-                "",
-                "",
-                "",
-            )
-            warnings.warn(
-                "Low Level context is None. Return empty Low entity/relationship/source"
-            )
-            query_param.mode = "global"
-        else:
-            (
-                ll_entities_context,
-                ll_relations_context,
-                ll_text_units_context,
-            ) = await _get_node_data(
-                ll_kewwords,
-                knowledge_graph_inst,
-                entities_vdb,
-                text_chunks_db,
-                query_param,
-            )
-    if query_param.mode in ["global", "hybrid"]:
-        if hl_keywrds == "":
-            hl_entities_context, hl_relations_context, hl_text_units_context = (
-                "",
-                "",
-                "",
-            )
-            warnings.warn(
-                "High Level context is None. Return empty High entity/relationship/source"
-            )
-            query_param.mode = "local"
-        else:
-            (
-                hl_entities_context,
-                hl_relations_context,
-                hl_text_units_context,
-            ) = await _get_edge_data(
-                hl_keywrds,
-                knowledge_graph_inst,
-                relationships_vdb,
-                text_chunks_db,
-                query_param,
-            )
-            if (
-                hl_entities_context == ""
-                and hl_relations_context == ""
-                and hl_text_units_context == ""
-            ):
-                logger.warn("No high level context found. Switching to local mode.")
-                query_param.mode = "local"
-    if query_param.mode == "hybrid":
-        entities_context, relations_context, text_units_context = combine_contexts(
-            [hl_entities_context, ll_entities_context],
-            [hl_relations_context, ll_relations_context],
-            [hl_text_units_context, ll_text_units_context],
-        )
-    elif query_param.mode == "local":
-        entities_context, relations_context, text_units_context = (
-            ll_entities_context,
-            ll_relations_context,
-            ll_text_units_context,
-        )
-    elif query_param.mode == "global":
-        entities_context, relations_context, text_units_context = (
-            hl_entities_context,
-            hl_relations_context,
-            hl_text_units_context,
-        )
+    ll_keywords, hl_keywords = query[0], query[1]
+
+    if query_param.mode == "local":
+        entities_context, relations_context, text_units_context = await _get_node_data(
+            ll_keywords,
+            knowledge_graph_inst,
+            entities_vdb,
+            text_chunks_db,
+            query_param,
+        )
+    elif query_param.mode == "global":
+        entities_context, relations_context, text_units_context = await _get_edge_data(
+            hl_keywords,
+            knowledge_graph_inst,
+            relationships_vdb,
+            text_chunks_db,
+            query_param,
+        )
+    else:  # hybrid mode
+        (
+            ll_entities_context,
+            ll_relations_context,
+            ll_text_units_context,
+        ) = await _get_node_data(
+            ll_keywords,
+            knowledge_graph_inst,
+            entities_vdb,
+            text_chunks_db,
+            query_param,
+        )
+        (
+            hl_entities_context,
+            hl_relations_context,
+            hl_text_units_context,
+        ) = await _get_edge_data(
+            hl_keywords,
+            knowledge_graph_inst,
+            relationships_vdb,
+            text_chunks_db,
+            query_param,
+        )
+        entities_context, relations_context, text_units_context = combine_contexts(
+            [hl_entities_context, ll_entities_context],
+            [hl_relations_context, ll_relations_context],
+            [hl_text_units_context, ll_text_units_context],
+        )
return f""" return f"""
-----Entities----- -----Entities-----
```csv ```csv
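The refactor replaces the old mutate-the-mode control flow with a single dispatch: local mode retrieves node data from low-level keywords, global mode retrieves edge data from high-level keywords, and hybrid retrieves both and merges them through `combine_contexts`. As a toy illustration of such a merge (this is **not** LightRAG's actual `combine_contexts`, whose body is not shown in this diff), deduplicating CSV rows while keeping high-level results first:

```python
def combine_csv_sections(hl_csv: str, ll_csv: str) -> str:
    # Toy merge: keep high-level rows first, then append unseen low-level rows.
    seen, merged = set(), []
    for row in hl_csv.splitlines() + ll_csv.splitlines():
        if row and row not in seen:
            seen.add(row)
            merged.append(row)
    return "\n".join(merged)

print(combine_csv_sections("id,entity\n0,Scrooge", "id,entity\n0,Scrooge\n1,Marley"))
# id,entity
# 0,Scrooge
# 1,Marley
```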


@@ -1,38 +1,38 @@
 accelerate
-aioboto3~=13.3.0
-aiofiles~=24.1.0
-aiohttp~=3.11.11
-asyncpg~=0.30.0
+aioboto3
+aiofiles
+aiohttp
+asyncpg
 # database packages
 graspologic
 gremlinpython
 hnswlib
 nano-vectordb
-neo4j~=5.27.0
-networkx~=3.2.1
-numpy~=2.2.0
-ollama~=0.4.4
-openai~=1.58.1
+neo4j
+networkx
+numpy
+ollama
+openai
 oracledb
-psycopg-pool~=3.2.4
-psycopg[binary,pool]~=3.2.3
-pydantic~=2.10.4
+psycopg-pool
+psycopg[binary,pool]
+pydantic
 pymilvus
 pymongo
 pymysql
-python-dotenv~=1.0.1
-pyvis~=0.3.2
-setuptools~=70.0.0
+python-dotenv
+pyvis
+setuptools
 # lmdeploy[all]
-sqlalchemy~=2.0.36
-tenacity~=9.0.0
+sqlalchemy
+tenacity
 # LLM packages
-tiktoken~=0.8.0
-torch~=2.5.1+cu121
-tqdm~=4.67.1
-transformers~=4.47.1
+tiktoken
+torch
+tqdm
+transformers
 xxhash


@@ -100,10 +100,7 @@ setuptools.setup(
     },
     entry_points={
         "console_scripts": [
-            "lollms-lightrag-server=lightrag.api.lollms_lightrag_server:main [api]",
-            "ollama-lightrag-server=lightrag.api.ollama_lightrag_server:main [api]",
-            "openai-lightrag-server=lightrag.api.openai_lightrag_server:main [api]",
-            "azure-openai-lightrag-server=lightrag.api.azure_openai_lightrag_server:main [api]",
+            "lightrag-server=lightrag.api.lightrag_server:main [api]",
         ],
     },
 )