diff --git a/README.md b/README.md
index 6602f1d3..33d36ae1 100644
--- a/README.md
+++ b/README.md
@@ -498,6 +498,10 @@ pip install fastapi uvicorn pydantic
 2. Set up your environment variables:
 ```bash
 export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
+export OPENAI_BASE_URL="Your OpenAI API base URL" # Optional: Defaults to "https://api.openai.com/v1"
+export OPENAI_API_KEY="Your OpenAI API key" # Required
+export LLM_MODEL="Your LLM model" # Optional: Defaults to "gpt-4o-mini"
+export EMBEDDING_MODEL="Your embedding model" # Optional: Defaults to "text-embedding-3-large"
 ```
 
 3. Run the API server:
@@ -522,7 +526,8 @@ The API server provides the following endpoints:
   ```json
   {
     "query": "Your question here",
-    "mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
+    "mode": "hybrid", // Can be "naive", "local", "global", or "hybrid"
+    "only_need_context": true // Optional, defaults to false. If true, only the retrieved context is returned instead of the LLM answer.
   }
   ```
 - **Example:**
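For reference, a request exercising the new `only_need_context` flag might look like the sketch below. The host and port (`localhost:8020`) are assumptions based on the demo's default `uvicorn` setup; adjust them to your deployment.

```bash
# Ask for the retrieved context only; omit only_need_context (or set it to
# false) to get the generated LLM answer instead.
curl -X POST "http://localhost:8020/query" \
     -H "Content-Type: application/json" \
     -d '{"query": "What are the top themes?", "mode": "hybrid", "only_need_context": true}'
```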
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
index 2cd262bb..20a05a5f 100644
--- a/examples/lightrag_api_openai_compatible_demo.py
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -1,4 +1,4 @@
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, File, UploadFile
 from pydantic import BaseModel
 import os
 from lightrag import LightRAG, QueryParam
@@ -18,22 +18,28 @@ app = FastAPI(title="LightRAG API", description="API for RAG operations")
 # Configure working directory
 WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
 print(f"WORKING_DIR: {WORKING_DIR}")
+LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
+print(f"LLM_MODEL: {LLM_MODEL}")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
+print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+
 # LLM model function
 async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+        prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     return await openai_complete_if_cache(
-        "gpt-4o-mini",
+        LLM_MODEL,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
-        api_key="YOUR_API_KEY",
-        base_url="YourURL/v1",
         **kwargs,
     )
 
@@ -44,37 +50,41 @@ async def llm_model_func(
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
-        model="text-embedding-3-large",
-        api_key="YOUR_API_KEY",
-        base_url="YourURL/v1",
+        model=EMBEDDING_MODEL,
     )
 
 
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    print(f"{embedding_dim=}")
+    return embedding_dim
+
+
 # Initialize RAG instance
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=llm_model_func,
-    embedding_func=EmbeddingFunc(
-        embedding_dim=3072, max_token_size=8192, func=embedding_func
-    ),
+    embedding_func=EmbeddingFunc(embedding_dim=asyncio.run(get_embedding_dim()),
+                                 max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+                                 func=embedding_func),
 )
 
+
 # Data models
 class QueryRequest(BaseModel):
     query: str
     mode: str = "hybrid"
+    only_need_context: bool = False
 
 
 class InsertRequest(BaseModel):
     text: str
 
 
-class InsertFileRequest(BaseModel):
-    file_path: str
-
-
 class Response(BaseModel):
     status: str
     data: Optional[str] = None
@@ -89,7 +99,8 @@ async def query_endpoint(request: QueryRequest):
     try:
         loop = asyncio.get_event_loop()
         result = await loop.run_in_executor(
-            None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
+            None, lambda: rag.query(request.query,
+                                    param=QueryParam(mode=request.mode, only_need_context=request.only_need_context))
         )
         return Response(status="success", data=result)
     except Exception as e:
@@ -107,30 +118,22 @@
 
 
 @app.post("/insert_file", response_model=Response)
-async def insert_file(request: InsertFileRequest):
+async def insert_file(file: UploadFile = File(...)):
     try:
-        # Check if file exists
-        if not os.path.exists(request.file_path):
-            raise HTTPException(
-                status_code=404, detail=f"File not found: {request.file_path}"
-            )
-
+        file_content = await file.read()
         # Read file content
         try:
-            with open(request.file_path, "r", encoding="utf-8") as f:
-                content = f.read()
+            content = file_content.decode("utf-8")
         except UnicodeDecodeError:
             # If UTF-8 decoding fails, try other encodings
-            with open(request.file_path, "r", encoding="gbk") as f:
-                content = f.read()
-
+            content = file_content.decode("gbk")
         # Insert file content
         loop = asyncio.get_event_loop()
         await loop.run_in_executor(None, lambda: rag.insert(content))
         return Response(
             status="success",
-            message=f"File content from {request.file_path} inserted successfully",
+            message=f"File content from {file.filename} inserted successfully",
         )
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
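Since `/insert_file` now accepts a multipart upload instead of a server-side path, clients send the file body itself. A minimal sketch, again assuming the demo server listens on `localhost:8020`; `book.txt` is a hypothetical local file:

```bash
# Upload the file as multipart/form-data; the server decodes the bytes as
# UTF-8 (falling back to GBK) before inserting the text into the index.
curl -X POST "http://localhost:8020/insert_file" \
     -F "file=@./book.txt"
```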