cleaned code
This commit is contained in:
@@ -13,13 +13,12 @@ import re
|
|||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
import logging
|
import logging
|
||||||
import argparse
|
import argparse
|
||||||
from typing import List, Any, Optional, Dict
|
from typing import List, Any, Literal, Optional, Dict
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.types import GPTKeywordExtractionFormat
|
from lightrag.types import GPTKeywordExtractionFormat
|
||||||
from lightrag.api import __api_version__
|
from lightrag.api import __api_version__
|
||||||
from lightrag.utils import EmbeddingFunc
|
from lightrag.utils import EmbeddingFunc
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shutil
|
import shutil
|
||||||
import aiofiles
|
import aiofiles
|
||||||
@@ -626,23 +625,11 @@ class DocumentManager:
|
|||||||
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
|
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
|
||||||
|
|
||||||
|
|
||||||
# LightRAG query mode
|
|
||||||
class SearchMode(str, Enum):
|
|
||||||
naive = "naive"
|
|
||||||
local = "local"
|
|
||||||
global_ = "global"
|
|
||||||
hybrid = "hybrid"
|
|
||||||
mix = "mix"
|
|
||||||
|
|
||||||
|
|
||||||
class QueryRequest(BaseModel):
|
class QueryRequest(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
|
|
||||||
"""Specifies the retrieval mode"""
|
"""Specifies the retrieval mode"""
|
||||||
mode: SearchMode = SearchMode.hybrid
|
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "hybrid"
|
||||||
|
|
||||||
"""If True, enables streaming output for real-time responses."""
|
|
||||||
stream: Optional[bool] = None
|
|
||||||
|
|
||||||
"""If True, only returns the retrieved context without generating a response."""
|
"""If True, only returns the retrieved context without generating a response."""
|
||||||
only_need_context: Optional[bool] = None
|
only_need_context: Optional[bool] = None
|
||||||
@@ -688,13 +675,18 @@ class InsertTextRequest(BaseModel):
|
|||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class InsertTextsRequest(BaseModel):
|
||||||
|
texts: list[str]
|
||||||
|
|
||||||
|
|
||||||
class InsertResponse(BaseModel):
|
class InsertResponse(BaseModel):
|
||||||
status: str
|
status: str
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
def QueryRequestToQueryParams(request: QueryRequest):
|
def QueryRequestToQueryParams(request: QueryRequest, is_stream: bool):
|
||||||
param = QueryParam(mode=request.mode, stream=request.stream)
|
param = QueryParam(mode=request.mode, stream=is_stream)
|
||||||
|
|
||||||
if request.only_need_context is not None:
|
if request.only_need_context is not None:
|
||||||
param.only_need_context = request.only_need_context
|
param.only_need_context = request.only_need_context
|
||||||
if request.only_need_prompt is not None:
|
if request.only_need_prompt is not None:
|
||||||
@@ -1523,7 +1515,7 @@ def create_app(args):
|
|||||||
logging.error(f"Error /documents/text: {str(e)}")
|
logging.error(f"Error /documents/text: {str(e)}")
|
||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
"/documents/file",
|
"/documents/file",
|
||||||
response_model=InsertResponse,
|
response_model=InsertResponse,
|
||||||
@@ -1653,9 +1645,7 @@ def create_app(args):
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
"/query",
|
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
|
||||||
response_model=QueryResponse,
|
|
||||||
dependencies=[Depends(optional_api_key)]
|
|
||||||
)
|
)
|
||||||
async def query_text(request: QueryRequest):
|
async def query_text(request: QueryRequest):
|
||||||
"""
|
"""
|
||||||
@@ -1674,7 +1664,7 @@ def create_app(args):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await rag.aquery(
|
response = await rag.aquery(
|
||||||
request.query, param=QueryRequestToQueryParams(request)
|
request.query, param=QueryRequestToQueryParams(request, False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# If response is a string (e.g. cache hit), return directly
|
# If response is a string (e.g. cache hit), return directly
|
||||||
@@ -1703,9 +1693,8 @@ def create_app(args):
|
|||||||
StreamingResponse: A streaming response containing the RAG query results.
|
StreamingResponse: A streaming response containing the RAG query results.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
params = QueryRequestToQueryParams(request)
|
params = QueryRequestToQueryParams(request, True)
|
||||||
|
|
||||||
params.stream = True
|
|
||||||
response = await rag.aquery( # Use aquery instead of query, and add await
|
response = await rag.aquery( # Use aquery instead of query, and add await
|
||||||
request.query, param=params
|
request.query, param=params
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user