Merge pull request #211 from monk-after-90s/main

Function enhancement
zrguo authored 2024-11-07 14:47:59 +08:00, committed by GitHub
2 changed files with 38 additions and 30 deletions

Changed file 1 of 2: the README's API server usage section.

@@ -498,6 +498,10 @@ pip install fastapi uvicorn pydantic
 2. Set up your environment variables:
 ```bash
 export RAG_DIR="your_index_directory"  # Optional: Defaults to "index_default"
+export OPENAI_BASE_URL="Your OpenAI API base URL"  # Optional: Defaults to "https://api.openai.com/v1"
+export OPENAI_API_KEY="Your OpenAI API key"  # Required
+export LLM_MODEL="Your LLM model"  # Optional: Defaults to "gpt-4o-mini"
+export EMBEDDING_MODEL="Your embedding model"  # Optional: Defaults to "text-embedding-3-large"
 ```
 3. Run the API server:

@@ -522,7 +526,8 @@ The API server provides the following endpoints:
 ```json
 {
     "query": "Your question here",
-    "mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
+    "mode": "hybrid", // Can be "naive", "local", "global", or "hybrid"
+    "only_need_context": true // Optional: Defaults to false; if true, only the retrieved context is returned, otherwise the LLM answer is returned
 }
 ```
 - **Example:**
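With the new `only_need_context` flag, a client can fetch the assembled retrieval context without paying for generation. A minimal client sketch in Python using `requests`; the `http://localhost:8020` base URL is an assumption about the demo's default port, not something this diff confirms:

```python
# Query the demo API and ask for the retrieved context only
# (skips the LLM answer). Base URL is an assumption; adjust as needed.
import requests

resp = requests.post(
    "http://localhost:8020/query",
    json={
        "query": "What are the main themes?",
        "mode": "hybrid",
        # New in this PR: return only the referenced context.
        "only_need_context": True,
    },
)
resp.raise_for_status()
print(resp.json())  # e.g. {"status": "success", "data": "<retrieved context>"}
```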

Changed file 2 of 2: the FastAPI demo script.

@@ -1,4 +1,4 @@
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, File, UploadFile
 from pydantic import BaseModel
 import os
 from lightrag import LightRAG, QueryParam
@@ -18,9 +18,17 @@ app = FastAPI(title="LightRAG API", description="API for RAG operations")
 # Configure working directory
 WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
 print(f"WORKING_DIR: {WORKING_DIR}")
+LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
+print(f"LLM_MODEL: {LLM_MODEL}")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
+print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 # LLM model function
@@ -28,12 +36,10 @@ async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     return await openai_complete_if_cache(
-        "gpt-4o-mini",
+        LLM_MODEL,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
-        api_key="YOUR_API_KEY",
-        base_url="YourURL/v1",
         **kwargs,
     )
@@ -44,37 +50,41 @@ async def llm_model_func(
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
-        model="text-embedding-3-large",
-        api_key="YOUR_API_KEY",
-        base_url="YourURL/v1",
+        model=EMBEDDING_MODEL,
     )
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    print(f"{embedding_dim=}")
+    return embedding_dim
 # Initialize RAG instance
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=llm_model_func,
-    embedding_func=EmbeddingFunc(
-        embedding_dim=3072, max_token_size=8192, func=embedding_func
-    ),
+    embedding_func=EmbeddingFunc(embedding_dim=asyncio.run(get_embedding_dim()),
+                                 max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+                                 func=embedding_func),
 )
 # Data models
 class QueryRequest(BaseModel):
     query: str
     mode: str = "hybrid"
+    only_need_context: bool = False
 class InsertRequest(BaseModel):
     text: str
-class InsertFileRequest(BaseModel):
-    file_path: str
 class Response(BaseModel):
     status: str
     data: Optional[str] = None
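This hunk replaces the hard-coded `embedding_dim=3072` with a startup-time probe: `get_embedding_dim()` embeds one test sentence and reads the vector width from the result's shape, so switching `EMBEDDING_MODEL` no longer requires editing the code. Because `asyncio.run(...)` executes at import time, the server makes a single embedding API call before it starts serving. A self-contained sketch of the pattern (the `fake_embedding` stand-in and its 3072 width are illustrative, not from the PR):

```python
import asyncio
import numpy as np

async def fake_embedding(texts: list[str]) -> np.ndarray:
    # Stand-in for openai_embedding: returns one row per input text.
    return np.zeros((len(texts), 3072))

async def get_embedding_dim() -> int:
    embedding = await fake_embedding(["This is a test sentence."])
    return embedding.shape[1]  # number of columns = embedding dimension

print(asyncio.run(get_embedding_dim()))  # prints 3072
```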
@@ -89,7 +99,8 @@ async def query_endpoint(request: QueryRequest):
     try:
         loop = asyncio.get_event_loop()
         result = await loop.run_in_executor(
-            None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
+            None, lambda: rag.query(request.query,
+                param=QueryParam(mode=request.mode, only_need_context=request.only_need_context))
         )
         return Response(status="success", data=result)
     except Exception as e:
@@ -107,30 +118,22 @@ async def insert_endpoint(request: InsertRequest):
 @app.post("/insert_file", response_model=Response)
-async def insert_file(request: InsertFileRequest):
+async def insert_file(file: UploadFile = File(...)):
     try:
-        # Check if file exists
-        if not os.path.exists(request.file_path):
-            raise HTTPException(
-                status_code=404, detail=f"File not found: {request.file_path}"
-            )
+        file_content = await file.read()
         # Read file content
         try:
-            with open(request.file_path, "r", encoding="utf-8") as f:
-                content = f.read()
+            content = file_content.decode("utf-8")
         except UnicodeDecodeError:
             # If UTF-8 decoding fails, try other encodings
-            with open(request.file_path, "r", encoding="gbk") as f:
-                content = f.read()
+            content = file_content.decode("gbk")
         # Insert file content
         loop = asyncio.get_event_loop()
         await loop.run_in_executor(None, lambda: rag.insert(content))
         return Response(
             status="success",
-            message=f"File content from {request.file_path} inserted successfully",
+            message=f"File content from {file.filename} inserted successfully",
         )
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
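Since `/insert_file` now takes a multipart upload instead of a server-side path, clients on other machines can push documents directly. One caveat visible in the diff: the handler returns `Response(status=..., message=...)`, while `Response` declares only `status` and `data`, so the `message` value is likely dropped under pydantic's default handling of unexpected fields. A minimal upload sketch with `requests` (base URL again assumed; the multipart field must be named `file` to bind to the `UploadFile` parameter):

```python
# Upload a local text file to the new multipart /insert_file endpoint.
# The base URL is an assumption; the field name "file" matches the
# UploadFile parameter in the handler above.
import requests

with open("doc.txt", "rb") as fh:
    resp = requests.post(
        "http://localhost:8020/insert_file",
        files={"file": ("doc.txt", fh, "text/plain")},
    )
resp.raise_for_status()
print(resp.json())
```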