cleaned code

This commit is contained in:
Yannick Stephan
2025-02-16 19:13:29 +01:00
parent 709461b875
commit b09589cfd9

View File

@@ -13,13 +13,12 @@ import re
from fastapi.staticfiles import StaticFiles
import logging
import argparse
from typing import List, Any, Optional, Dict
from typing import List, Any, Literal, Optional, Dict
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
@@ -626,23 +625,11 @@ class DocumentManager:
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# LightRAG query mode
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
mix = "mix"
class QueryRequest(BaseModel):
query: str
"""Specifies the retrieval mode"""
mode: SearchMode = SearchMode.hybrid
"""If True, enables streaming output for real-time responses."""
stream: Optional[bool] = None
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "hybrid"
"""If True, only returns the retrieved context without generating a response."""
only_need_context: Optional[bool] = None
@@ -688,13 +675,18 @@ class InsertTextRequest(BaseModel):
text: str
class InsertTextsRequest(BaseModel):
texts: list[str]
class InsertResponse(BaseModel):
status: str
message: str
def QueryRequestToQueryParams(request: QueryRequest):
param = QueryParam(mode=request.mode, stream=request.stream)
def QueryRequestToQueryParams(request: QueryRequest, is_stream: bool):
param = QueryParam(mode=request.mode, stream=is_stream)
if request.only_need_context is not None:
param.only_need_context = request.only_need_context
if request.only_need_prompt is not None:
@@ -1653,9 +1645,7 @@ def create_app(args):
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/query",
response_model=QueryResponse,
dependencies=[Depends(optional_api_key)]
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
"""
@@ -1674,7 +1664,7 @@ def create_app(args):
"""
try:
response = await rag.aquery(
request.query, param=QueryRequestToQueryParams(request)
request.query, param=QueryRequestToQueryParams(request, False)
)
# If response is a string (e.g. cache hit), return directly
@@ -1703,9 +1693,8 @@ def create_app(args):
StreamingResponse: A streaming response containing the RAG query results.
"""
try:
params = QueryRequestToQueryParams(request)
params = QueryRequestToQueryParams(request, True)
params.stream = True
response = await rag.aquery( # Use aquery instead of query, and add await
request.query, param=params
)