cleaned code
This commit is contained in:
@@ -13,13 +13,12 @@ import re
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import logging
|
||||
import argparse
|
||||
from typing import List, Any, Optional, Dict
|
||||
from typing import List, Any, Literal, Optional, Dict
|
||||
from pydantic import BaseModel
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.types import GPTKeywordExtractionFormat
|
||||
from lightrag.api import __api_version__
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import aiofiles
|
||||
@@ -626,23 +625,11 @@ class DocumentManager:
|
||||
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
|
||||
|
||||
|
||||
# LightRAG query mode
|
||||
class SearchMode(str, Enum):
|
||||
naive = "naive"
|
||||
local = "local"
|
||||
global_ = "global"
|
||||
hybrid = "hybrid"
|
||||
mix = "mix"
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
"""Specifies the retrieval mode"""
|
||||
mode: SearchMode = SearchMode.hybrid
|
||||
|
||||
"""If True, enables streaming output for real-time responses."""
|
||||
stream: Optional[bool] = None
|
||||
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "hybrid"
|
||||
|
||||
"""If True, only returns the retrieved context without generating a response."""
|
||||
only_need_context: Optional[bool] = None
|
||||
@@ -688,13 +675,18 @@ class InsertTextRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class InsertTextsRequest(BaseModel):
|
||||
texts: list[str]
|
||||
|
||||
|
||||
class InsertResponse(BaseModel):
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
def QueryRequestToQueryParams(request: QueryRequest):
|
||||
param = QueryParam(mode=request.mode, stream=request.stream)
|
||||
def QueryRequestToQueryParams(request: QueryRequest, is_stream: bool):
|
||||
param = QueryParam(mode=request.mode, stream=is_stream)
|
||||
|
||||
if request.only_need_context is not None:
|
||||
param.only_need_context = request.only_need_context
|
||||
if request.only_need_prompt is not None:
|
||||
@@ -1523,7 +1515,7 @@ def create_app(args):
|
||||
logging.error(f"Error /documents/text: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post(
|
||||
"/documents/file",
|
||||
response_model=InsertResponse,
|
||||
@@ -1653,9 +1645,7 @@ def create_app(args):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post(
|
||||
"/query",
|
||||
response_model=QueryResponse,
|
||||
dependencies=[Depends(optional_api_key)]
|
||||
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
|
||||
)
|
||||
async def query_text(request: QueryRequest):
|
||||
"""
|
||||
@@ -1674,7 +1664,7 @@ def create_app(args):
|
||||
"""
|
||||
try:
|
||||
response = await rag.aquery(
|
||||
request.query, param=QueryRequestToQueryParams(request)
|
||||
request.query, param=QueryRequestToQueryParams(request, False)
|
||||
)
|
||||
|
||||
# If response is a string (e.g. cache hit), return directly
|
||||
@@ -1703,9 +1693,8 @@ def create_app(args):
|
||||
StreamingResponse: A streaming response containing the RAG query results.
|
||||
"""
|
||||
try:
|
||||
params = QueryRequestToQueryParams(request)
|
||||
params = QueryRequestToQueryParams(request, True)
|
||||
|
||||
params.stream = True
|
||||
response = await rag.aquery( # Use aquery instead of query, and add await
|
||||
request.query, param=params
|
||||
)
|
||||
|
Reference in New Issue
Block a user