Fix linting
@@ -4,7 +4,6 @@ LightRAG FastAPI Server
 
 from fastapi import (
     FastAPI,
-    HTTPException,
     Depends,
 )
 from fastapi.responses import FileResponse
@@ -14,7 +13,7 @@ import os
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import Optional, Dict
+from typing import Dict
 from pathlib import Path
 import configparser
 from ascii_colors import ASCIIColors
@@ -73,6 +72,7 @@ scan_progress: Dict = {
 # Lock for thread-safe operations
 progress_lock = threading.Lock()
 
+
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
         "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -624,7 +624,9 @@ def create_app(args):
         scan_progress["indexed_count"] = 0
         scan_progress["progress"] = 0
         # Create background task
-        task = asyncio.create_task(run_scanning_process(rag, doc_manager))
+        task = asyncio.create_task(
+            run_scanning_process(rag, doc_manager)
+        )
         app.state.background_tasks.add(task)
         task.add_done_callback(app.state.background_tasks.discard)
         ASCIIColors.info(
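Note: the hunk above only reflows the asyncio.create_task(...) call; the unchanged lines around it already do the usual bookkeeping of keeping a strong reference to the task and discarding it on completion, so the scan task cannot be garbage-collected while it runs. A minimal standalone sketch of that pattern, with an illustrative placeholder coroutine in place of run_scanning_process(rag, doc_manager):

import asyncio

background_tasks: set[asyncio.Task] = set()  # strong references keep pending tasks alive


async def scan_documents() -> None:
    # stand-in for run_scanning_process(rag, doc_manager)
    await asyncio.sleep(0.1)


async def main() -> None:
    task = asyncio.create_task(scan_documents())
    background_tasks.add(task)                        # hold a reference
    task.add_done_callback(background_tasks.discard)  # drop it once finished
    await asyncio.gather(*background_tasks)


asyncio.run(main())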
@@ -876,7 +878,11 @@ def create_app(args):
     # Webui mount webui/index.html
     static_dir = Path(__file__).parent / "webui"
     static_dir.mkdir(exist_ok=True)
-    app.mount("/webui", StaticFiles(directory=static_dir, html=True, check_dir=True), name="webui")
+    app.mount(
+        "/webui",
+        StaticFiles(directory=static_dir, html=True, check_dir=True),
+        name="webui",
+    )
 
     @app.get("/webui/")
     async def webui_root():
@@ -14,9 +14,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, Any
 
 from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
-from fastapi.security import APIKeyHeader
 from pydantic import BaseModel, Field, field_validator
-from starlette.status import HTTP_403_FORBIDDEN
 
 from lightrag.base import DocProcessingStatus, DocStatus
 from ..utils_api import get_api_key_dependency
@@ -39,6 +37,7 @@ progress_lock = asyncio.Lock()
 # Temporary file prefix
 temp_prefix = "__tmp__"
 
+
 class InsertTextRequest(BaseModel):
     text: str = Field(
         min_length=1,
@@ -50,6 +49,7 @@ class InsertTextRequest(BaseModel):
     def strip_after(cls, text: str) -> str:
         return text.strip()
 
+
 class InsertTextsRequest(BaseModel):
     texts: list[str] = Field(
         min_length=1,
@@ -61,10 +61,12 @@ class InsertTextsRequest(BaseModel):
     def strip_after(cls, texts: list[str]) -> list[str]:
         return [text.strip() for text in texts]
 
+
 class InsertResponse(BaseModel):
     status: str = Field(description="Status of the operation")
     message: str = Field(description="Message describing the operation result")
 
+
 class DocStatusResponse(BaseModel):
     @staticmethod
     def format_datetime(dt: Any) -> Optional[str]:
@@ -84,9 +86,11 @@ class DocStatusResponse(BaseModel):
     error: Optional[str] = None
     metadata: Optional[dict[str, Any]] = None
 
+
 class DocsStatusesResponse(BaseModel):
     statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
 
+
 class DocumentManager:
     def __init__(
         self,
@@ -129,6 +133,7 @@ class DocumentManager:
     def is_supported_file(self, filename: str) -> bool:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
 
+
 async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
     try:
         content = ""
@@ -145,7 +150,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from PyPDF2 import PdfReader # type: ignore
+                from PyPDF2 import PdfReader  # type: ignore
                 from io import BytesIO
 
                 pdf_file = BytesIO(file)
@@ -184,10 +189,17 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
                 for sheet in wb:
                     content += f"Sheet: {sheet.title}\n"
                     for row in sheet.iter_rows(values_only=True):
-                        content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
+                        content += (
+                            "\t".join(
+                                str(cell) if cell is not None else "" for cell in row
+                            )
+                            + "\n"
+                        )
                     content += "\n"
             case _:
-                logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
+                logging.error(
+                    f"Unsupported file type: {file_path.name} (extension {ext})"
+                )
                 return False
 
         # Insert into the RAG queue
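Aside: the reflowed expression above is purely formatting; the logic is still "tab-join each row, blank cells become empty strings". A rough self-contained sketch of that .xlsx flattening with openpyxl (the helper name and the read_only flag are illustrative assumptions, not taken from this diff):

from io import BytesIO

from openpyxl import load_workbook  # assumed installed; the routes install it on demand via pm.install


def xlsx_to_text(raw_bytes: bytes) -> str:
    """Flatten every sheet of an .xlsx file into tab-separated text."""
    wb = load_workbook(BytesIO(raw_bytes), read_only=True)
    content = ""
    for sheet in wb:
        content += f"Sheet: {sheet.title}\n"
        for row in sheet.iter_rows(values_only=True):
            content += (
                "\t".join(str(cell) if cell is not None else "" for cell in row)
                + "\n"
            )
        content += "\n"
    return content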
@@ -209,6 +221,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
         logging.error(f"Error deleting file {file_path}: {str(e)}")
         return False
 
+
 async def pipeline_index_file(rag, file_path: Path):
     """Index a file
 
@@ -231,7 +244,7 @@ async def pipeline_index_file(rag, file_path: Path):
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from PyPDF2 import PdfReader # type: ignore
+                from PyPDF2 import PdfReader  # type: ignore
                 from io import BytesIO
 
                 pdf_file = BytesIO(file)
@@ -270,10 +283,17 @@ async def pipeline_index_file(rag, file_path: Path):
                 for sheet in wb:
                     content += f"Sheet: {sheet.title}\n"
                     for row in sheet.iter_rows(values_only=True):
-                        content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
+                        content += (
+                            "\t".join(
+                                str(cell) if cell is not None else "" for cell in row
+                            )
+                            + "\n"
+                        )
                     content += "\n"
             case _:
-                logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
+                logging.error(
+                    f"Unsupported file type: {file_path.name} (extension {ext})"
+                )
                 return
 
         # Insert into the RAG queue
@@ -288,6 +308,7 @@ async def pipeline_index_file(rag, file_path: Path):
         logging.error(f"Error indexing file {file_path.name}: {str(e)}")
         logging.error(traceback.format_exc())
 
+
 async def pipeline_index_files(rag, file_paths: List[Path]):
     if not file_paths:
         return
@@ -305,12 +326,14 @@ async def pipeline_index_files(rag, file_paths: List[Path]):
         logging.error(f"Error indexing files: {str(e)}")
         logging.error(traceback.format_exc())
 
+
 async def pipeline_index_texts(rag, texts: List[str]):
     if not texts:
         return
     await rag.apipeline_enqueue_documents(texts)
     await rag.apipeline_process_enqueue_documents()
 
+
 async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
@@ -320,6 +343,7 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
         shutil.copyfileobj(file.file, buffer)
     return temp_path
 
+
 async def run_scanning_process(rag, doc_manager: DocumentManager):
     """Background task to scan and index documents"""
     try:
@@ -349,7 +373,10 @@ async def run_scanning_process(rag, doc_manager: DocumentManager):
         async with progress_lock:
             scan_progress["is_scanning"] = False
 
-def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None):
+
+def create_document_routes(
+    rag, doc_manager: DocumentManager, api_key: Optional[str] = None
+):
     optional_api_key = get_api_key_dependency(api_key)
 
     @router.post("/scan", dependencies=[Depends(optional_api_key)])
@@ -437,8 +464,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
-    @router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
-    async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks):
+    @router.post(
+        "/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
+    )
+    async def insert_text(
+        request: InsertTextRequest, background_tasks: BackgroundTasks
+    ):
         """
         Insert text into the RAG system.
 
@@ -466,8 +497,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
-    @router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
-    async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks):
+    @router.post(
+        "/texts",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_texts(
+        request: InsertTextsRequest, background_tasks: BackgroundTasks
+    ):
         """
         Insert multiple texts into the RAG system.
 
@@ -495,8 +532,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
-    @router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
-    async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
+    @router.post(
+        "/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
+    )
+    async def insert_file(
+        background_tasks: BackgroundTasks, file: UploadFile = File(...)
+    ):
         """
         Insert a file directly into the RAG system.
 
@@ -532,8 +573,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
-    @router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
-    async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
+    @router.post(
+        "/file_batch",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_batch(
+        background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
+    ):
         """
         Process multiple files in batch mode.
 
@@ -587,7 +634,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
-    @router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
+    @router.delete(
+        "", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
+    )
     async def clear_documents():
         """
         Clear all documents from the RAG system.
@@ -605,7 +654,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
             rag.text_chunks = []
             rag.entities_vdb = None
             rag.relationships_vdb = None
-            return InsertResponse(status="success", message="All documents cleared successfully")
+            return InsertResponse(
+                status="success", message="All documents cleared successfully"
+            )
         except Exception as e:
             logging.error(f"Error DELETE /documents: {str(e)}")
             logging.error(traceback.format_exc())
@@ -651,8 +702,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
                         content_summary=doc_status.content_summary,
                         content_length=doc_status.content_length,
                         status=doc_status.status,
-                        created_at=DocStatusResponse.format_datetime(doc_status.created_at),
-                        updated_at=DocStatusResponse.format_datetime(doc_status.updated_at),
+                        created_at=DocStatusResponse.format_datetime(
+                            doc_status.created_at
+                        ),
+                        updated_at=DocStatusResponse.format_datetime(
+                            doc_status.updated_at
+                        ),
                         chunks_count=doc_status.chunks_count,
                         error=doc_status.error,
                         metadata=doc_status.metadata,
@@ -4,12 +4,13 @@ This module contains all graph-related routes for the LightRAG API.
 
 from typing import Optional
 
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends
+
 from ..utils_api import get_api_key_dependency
 
 router = APIRouter(tags=["graph"])
 
 
 def create_graph_routes(rag, api_key: Optional[str] = None):
     optional_api_key = get_api_key_dependency(api_key)
 
@@ -4,7 +4,6 @@ This module contains all query-related routes for the LightRAG API.
 
 import json
 import logging
-import traceback
 from typing import Any, Dict, List, Literal, Optional
 
 from fastapi import APIRouter, Depends, HTTPException
@@ -16,6 +15,7 @@ from ascii_colors import trace_exception
 
 router = APIRouter(tags=["query"])
 
+
 class QueryRequest(BaseModel):
     query: str = Field(
         min_length=1,
@@ -131,15 +131,19 @@ class QueryRequest(BaseModel):
         param.stream = is_stream
         return param
 
+
 class QueryResponse(BaseModel):
     response: str = Field(
         description="The generated response",
     )
 
+
 def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
     optional_api_key = get_api_key_dependency(api_key)
 
-    @router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
+    @router.post(
+        "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
+    )
     async def query_text(request: QueryRequest):
         """
         Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
@@ -7,6 +7,7 @@ from fastapi import HTTPException, Security
 from fastapi.security import APIKeyHeader
 from starlette.status import HTTP_403_FORBIDDEN
 
+
 def get_api_key_dependency(api_key: Optional[str]):
     """
     Create an API key dependency for route protection.
@@ -1,6 +1,6 @@
-future
 aiohttp
 configparser
+future
 
 # Basic modules
 numpy