Fit linting

This commit is contained in:
yangdx
2025-02-20 04:12:21 +08:00
parent f776db0779
commit a8abcf14ac
7 changed files with 98 additions and 31 deletions

2
.gitignore vendored
View File

@@ -62,4 +62,4 @@ lightrag-dev/
gui/
# unit-test files
test_*
test_*

View File

@@ -4,7 +4,6 @@ LightRAG FastAPI Server
from fastapi import (
FastAPI,
HTTPException,
Depends,
)
from fastapi.responses import FileResponse
@@ -14,7 +13,7 @@ import os
from fastapi.staticfiles import StaticFiles
import logging
import argparse
from typing import Optional, Dict
from typing import Dict
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
@@ -73,6 +72,7 @@ scan_progress: Dict = {
# Lock for thread-safe operations
progress_lock = threading.Lock()
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -624,7 +624,9 @@ def create_app(args):
scan_progress["indexed_count"] = 0
scan_progress["progress"] = 0
# Create background task
task = asyncio.create_task(run_scanning_process(rag, doc_manager))
task = asyncio.create_task(
run_scanning_process(rag, doc_manager)
)
app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard)
ASCIIColors.info(
@@ -876,8 +878,12 @@ def create_app(args):
# Webui mount webui/index.html
static_dir = Path(__file__).parent / "webui"
static_dir.mkdir(exist_ok=True)
app.mount("/webui", StaticFiles(directory=static_dir, html=True, check_dir=True), name="webui")
app.mount(
"/webui",
StaticFiles(directory=static_dir, html=True, check_dir=True),
name="webui",
)
@app.get("/webui/")
async def webui_root():
return FileResponse(static_dir / "index.html")

View File

@@ -14,9 +14,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Any
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, field_validator
from starlette.status import HTTP_403_FORBIDDEN
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency
@@ -39,6 +37,7 @@ progress_lock = asyncio.Lock()
# Temporary file prefix
temp_prefix = "__tmp__"
class InsertTextRequest(BaseModel):
text: str = Field(
min_length=1,
@@ -50,6 +49,7 @@ class InsertTextRequest(BaseModel):
def strip_after(cls, text: str) -> str:
return text.strip()
class InsertTextsRequest(BaseModel):
texts: list[str] = Field(
min_length=1,
@@ -61,10 +61,12 @@ class InsertTextsRequest(BaseModel):
def strip_after(cls, texts: list[str]) -> list[str]:
return [text.strip() for text in texts]
class InsertResponse(BaseModel):
status: str = Field(description="Status of the operation")
message: str = Field(description="Message describing the operation result")
class DocStatusResponse(BaseModel):
@staticmethod
def format_datetime(dt: Any) -> Optional[str]:
@@ -84,9 +86,11 @@ class DocStatusResponse(BaseModel):
error: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
class DocsStatusesResponse(BaseModel):
statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
class DocumentManager:
def __init__(
self,
@@ -129,6 +133,7 @@ class DocumentManager:
def is_supported_file(self, filename: str) -> bool:
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
try:
content = ""
@@ -145,7 +150,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
pdf_file = BytesIO(file)
@@ -184,10 +189,17 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
content += (
"\t".join(
str(cell) if cell is not None else "" for cell in row
)
+ "\n"
)
content += "\n"
case _:
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
logging.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return False
# Insert into the RAG queue
@@ -209,6 +221,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
logging.error(f"Error deleting file {file_path}: {str(e)}")
return False
async def pipeline_index_file(rag, file_path: Path):
"""Index a file
@@ -231,7 +244,7 @@ async def pipeline_index_file(rag, file_path: Path):
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
pdf_file = BytesIO(file)
@@ -270,10 +283,17 @@ async def pipeline_index_file(rag, file_path: Path):
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
content += (
"\t".join(
str(cell) if cell is not None else "" for cell in row
)
+ "\n"
)
content += "\n"
case _:
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
logging.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return
# Insert into the RAG queue
@@ -288,6 +308,7 @@ async def pipeline_index_file(rag, file_path: Path):
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_files(rag, file_paths: List[Path]):
if not file_paths:
return
@@ -305,12 +326,14 @@ async def pipeline_index_files(rag, file_paths: List[Path]):
logging.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc())
async def pipeline_index_texts(rag, texts: List[str]):
if not texts:
return
await rag.apipeline_enqueue_documents(texts)
await rag.apipeline_process_enqueue_documents()
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
@@ -320,6 +343,7 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
shutil.copyfileobj(file.file, buffer)
return temp_path
async def run_scanning_process(rag, doc_manager: DocumentManager):
"""Background task to scan and index documents"""
try:
@@ -349,7 +373,10 @@ async def run_scanning_process(rag, doc_manager: DocumentManager):
async with progress_lock:
scan_progress["is_scanning"] = False
def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None):
def create_document_routes(
rag, doc_manager: DocumentManager, api_key: Optional[str] = None
):
optional_api_key = get_api_key_dependency(api_key)
@router.post("/scan", dependencies=[Depends(optional_api_key)])
@@ -437,8 +464,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks):
@router.post(
"/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
)
async def insert_text(
request: InsertTextRequest, background_tasks: BackgroundTasks
):
"""
Insert text into the RAG system.
@@ -466,8 +497,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks):
@router.post(
"/texts",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_texts(
request: InsertTextsRequest, background_tasks: BackgroundTasks
):
"""
Insert multiple texts into the RAG system.
@@ -495,8 +532,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
@router.post(
"/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
)
async def insert_file(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
"""
Insert a file directly into the RAG system.
@@ -532,8 +573,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
@router.post(
"/file_batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
):
"""
Process multiple files in batch mode.
@@ -587,7 +634,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
@router.delete(
"", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
)
async def clear_documents():
"""
Clear all documents from the RAG system.
@@ -605,7 +654,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(status="success", message="All documents cleared successfully")
return InsertResponse(
status="success", message="All documents cleared successfully"
)
except Exception as e:
logging.error(f"Error DELETE /documents: {str(e)}")
logging.error(traceback.format_exc())
@@ -651,8 +702,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
content_summary=doc_status.content_summary,
content_length=doc_status.content_length,
status=doc_status.status,
created_at=DocStatusResponse.format_datetime(doc_status.created_at),
updated_at=DocStatusResponse.format_datetime(doc_status.updated_at),
created_at=DocStatusResponse.format_datetime(
doc_status.created_at
),
updated_at=DocStatusResponse.format_datetime(
doc_status.updated_at
),
chunks_count=doc_status.chunks_count,
error=doc_status.error,
metadata=doc_status.metadata,

View File

@@ -4,12 +4,13 @@ This module contains all graph-related routes for the LightRAG API.
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency
router = APIRouter(tags=["graph"])
def create_graph_routes(rag, api_key: Optional[str] = None):
optional_api_key = get_api_key_dependency(api_key)

View File

@@ -4,7 +4,6 @@ This module contains all query-related routes for the LightRAG API.
import json
import logging
import traceback
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
@@ -16,6 +15,7 @@ from ascii_colors import trace_exception
router = APIRouter(tags=["query"])
class QueryRequest(BaseModel):
query: str = Field(
min_length=1,
@@ -131,15 +131,19 @@ class QueryRequest(BaseModel):
param.stream = is_stream
return param
class QueryResponse(BaseModel):
response: str = Field(
description="The generated response",
)
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
optional_api_key = get_api_key_dependency(api_key)
@router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
@router.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
"""
Handle a POST request at the /query endpoint to process user queries using RAG capabilities.

View File

@@ -7,6 +7,7 @@ from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN
def get_api_key_dependency(api_key: Optional[str]):
"""
Create an API key dependency for route protection.

View File

@@ -1,6 +1,6 @@
future
aiohttp
configparser
future
# Basic modules
numpy