Fit linting

This commit is contained in:
yangdx
2025-02-20 04:12:21 +08:00
parent f776db0779
commit a8abcf14ac
7 changed files with 98 additions and 31 deletions

View File

@@ -4,7 +4,6 @@ LightRAG FastAPI Server
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
HTTPException,
Depends, Depends,
) )
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@@ -14,7 +13,7 @@ import os
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
import logging import logging
import argparse import argparse
from typing import Optional, Dict from typing import Dict
from pathlib import Path from pathlib import Path
import configparser import configparser
from ascii_colors import ASCIIColors from ascii_colors import ASCIIColors
@@ -73,6 +72,7 @@ scan_progress: Dict = {
# Lock for thread-safe operations # Lock for thread-safe operations
progress_lock = threading.Lock() progress_lock = threading.Lock()
def get_default_host(binding_type: str) -> str: def get_default_host(binding_type: str) -> str:
default_hosts = { default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -624,7 +624,9 @@ def create_app(args):
scan_progress["indexed_count"] = 0 scan_progress["indexed_count"] = 0
scan_progress["progress"] = 0 scan_progress["progress"] = 0
# Create background task # Create background task
task = asyncio.create_task(run_scanning_process(rag, doc_manager)) task = asyncio.create_task(
run_scanning_process(rag, doc_manager)
)
app.state.background_tasks.add(task) app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard) task.add_done_callback(app.state.background_tasks.discard)
ASCIIColors.info( ASCIIColors.info(
@@ -876,7 +878,11 @@ def create_app(args):
# Webui mount webui/index.html # Webui mount webui/index.html
static_dir = Path(__file__).parent / "webui" static_dir = Path(__file__).parent / "webui"
static_dir.mkdir(exist_ok=True) static_dir.mkdir(exist_ok=True)
app.mount("/webui", StaticFiles(directory=static_dir, html=True, check_dir=True), name="webui") app.mount(
"/webui",
StaticFiles(directory=static_dir, html=True, check_dir=True),
name="webui",
)
@app.get("/webui/") @app.get("/webui/")
async def webui_root(): async def webui_root():

View File

@@ -14,9 +14,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from starlette.status import HTTP_403_FORBIDDEN
from lightrag.base import DocProcessingStatus, DocStatus from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency from ..utils_api import get_api_key_dependency
@@ -39,6 +37,7 @@ progress_lock = asyncio.Lock()
# Temporary file prefix # Temporary file prefix
temp_prefix = "__tmp__" temp_prefix = "__tmp__"
class InsertTextRequest(BaseModel): class InsertTextRequest(BaseModel):
text: str = Field( text: str = Field(
min_length=1, min_length=1,
@@ -50,6 +49,7 @@ class InsertTextRequest(BaseModel):
def strip_after(cls, text: str) -> str: def strip_after(cls, text: str) -> str:
return text.strip() return text.strip()
class InsertTextsRequest(BaseModel): class InsertTextsRequest(BaseModel):
texts: list[str] = Field( texts: list[str] = Field(
min_length=1, min_length=1,
@@ -61,10 +61,12 @@ class InsertTextsRequest(BaseModel):
def strip_after(cls, texts: list[str]) -> list[str]: def strip_after(cls, texts: list[str]) -> list[str]:
return [text.strip() for text in texts] return [text.strip() for text in texts]
class InsertResponse(BaseModel): class InsertResponse(BaseModel):
status: str = Field(description="Status of the operation") status: str = Field(description="Status of the operation")
message: str = Field(description="Message describing the operation result") message: str = Field(description="Message describing the operation result")
class DocStatusResponse(BaseModel): class DocStatusResponse(BaseModel):
@staticmethod @staticmethod
def format_datetime(dt: Any) -> Optional[str]: def format_datetime(dt: Any) -> Optional[str]:
@@ -84,9 +86,11 @@ class DocStatusResponse(BaseModel):
error: Optional[str] = None error: Optional[str] = None
metadata: Optional[dict[str, Any]] = None metadata: Optional[dict[str, Any]] = None
class DocsStatusesResponse(BaseModel): class DocsStatusesResponse(BaseModel):
statuses: Dict[DocStatus, List[DocStatusResponse]] = {} statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
class DocumentManager: class DocumentManager:
def __init__( def __init__(
self, self,
@@ -129,6 +133,7 @@ class DocumentManager:
def is_supported_file(self, filename: str) -> bool: def is_supported_file(self, filename: str) -> bool:
return any(filename.lower().endswith(ext) for ext in self.supported_extensions) return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
async def pipeline_enqueue_file(rag, file_path: Path) -> bool: async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
try: try:
content = "" content = ""
@@ -184,10 +189,17 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
for sheet in wb: for sheet in wb:
content += f"Sheet: {sheet.title}\n" content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True): for row in sheet.iter_rows(values_only=True):
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" content += (
"\t".join(
str(cell) if cell is not None else "" for cell in row
)
+ "\n"
)
content += "\n" content += "\n"
case _: case _:
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") logging.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return False return False
# Insert into the RAG queue # Insert into the RAG queue
@@ -209,6 +221,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
logging.error(f"Error deleting file {file_path}: {str(e)}") logging.error(f"Error deleting file {file_path}: {str(e)}")
return False return False
async def pipeline_index_file(rag, file_path: Path): async def pipeline_index_file(rag, file_path: Path):
"""Index a file """Index a file
@@ -270,10 +283,17 @@ async def pipeline_index_file(rag, file_path: Path):
for sheet in wb: for sheet in wb:
content += f"Sheet: {sheet.title}\n" content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True): for row in sheet.iter_rows(values_only=True):
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n" content += (
"\t".join(
str(cell) if cell is not None else "" for cell in row
)
+ "\n"
)
content += "\n" content += "\n"
case _: case _:
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})") logging.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return return
# Insert into the RAG queue # Insert into the RAG queue
@@ -288,6 +308,7 @@ async def pipeline_index_file(rag, file_path: Path):
logging.error(f"Error indexing file {file_path.name}: {str(e)}") logging.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
async def pipeline_index_files(rag, file_paths: List[Path]): async def pipeline_index_files(rag, file_paths: List[Path]):
if not file_paths: if not file_paths:
return return
@@ -305,12 +326,14 @@ async def pipeline_index_files(rag, file_paths: List[Path]):
logging.error(f"Error indexing files: {str(e)}") logging.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
async def pipeline_index_texts(rag, texts: List[str]): async def pipeline_index_texts(rag, texts: List[str]):
if not texts: if not texts:
return return
await rag.apipeline_enqueue_documents(texts) await rag.apipeline_enqueue_documents(texts)
await rag.apipeline_process_enqueue_documents() await rag.apipeline_process_enqueue_documents()
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
unique_filename = f"{temp_prefix}{timestamp}_{file.filename}" unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
@@ -320,6 +343,7 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
shutil.copyfileobj(file.file, buffer) shutil.copyfileobj(file.file, buffer)
return temp_path return temp_path
async def run_scanning_process(rag, doc_manager: DocumentManager): async def run_scanning_process(rag, doc_manager: DocumentManager):
"""Background task to scan and index documents""" """Background task to scan and index documents"""
try: try:
@@ -349,7 +373,10 @@ async def run_scanning_process(rag, doc_manager: DocumentManager):
async with progress_lock: async with progress_lock:
scan_progress["is_scanning"] = False scan_progress["is_scanning"] = False
def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None):
def create_document_routes(
rag, doc_manager: DocumentManager, api_key: Optional[str] = None
):
optional_api_key = get_api_key_dependency(api_key) optional_api_key = get_api_key_dependency(api_key)
@router.post("/scan", dependencies=[Depends(optional_api_key)]) @router.post("/scan", dependencies=[Depends(optional_api_key)])
@@ -437,8 +464,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) @router.post(
async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks): "/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
)
async def insert_text(
request: InsertTextRequest, background_tasks: BackgroundTasks
):
""" """
Insert text into the RAG system. Insert text into the RAG system.
@@ -466,8 +497,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) @router.post(
async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks): "/texts",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_texts(
request: InsertTextsRequest, background_tasks: BackgroundTasks
):
""" """
Insert multiple texts into the RAG system. Insert multiple texts into the RAG system.
@@ -495,8 +532,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) @router.post(
async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)): "/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
)
async def insert_file(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
):
""" """
Insert a file directly into the RAG system. Insert a file directly into the RAG system.
@@ -532,8 +573,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) @router.post(
async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)): "/file_batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
):
""" """
Process multiple files in batch mode. Process multiple files in batch mode.
@@ -587,7 +634,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) @router.delete(
"", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
)
async def clear_documents(): async def clear_documents():
""" """
Clear all documents from the RAG system. Clear all documents from the RAG system.
@@ -605,7 +654,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
rag.text_chunks = [] rag.text_chunks = []
rag.entities_vdb = None rag.entities_vdb = None
rag.relationships_vdb = None rag.relationships_vdb = None
return InsertResponse(status="success", message="All documents cleared successfully") return InsertResponse(
status="success", message="All documents cleared successfully"
)
except Exception as e: except Exception as e:
logging.error(f"Error DELETE /documents: {str(e)}") logging.error(f"Error DELETE /documents: {str(e)}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
@@ -651,8 +702,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
content_summary=doc_status.content_summary, content_summary=doc_status.content_summary,
content_length=doc_status.content_length, content_length=doc_status.content_length,
status=doc_status.status, status=doc_status.status,
created_at=DocStatusResponse.format_datetime(doc_status.created_at), created_at=DocStatusResponse.format_datetime(
updated_at=DocStatusResponse.format_datetime(doc_status.updated_at), doc_status.created_at
),
updated_at=DocStatusResponse.format_datetime(
doc_status.updated_at
),
chunks_count=doc_status.chunks_count, chunks_count=doc_status.chunks_count,
error=doc_status.error, error=doc_status.error,
metadata=doc_status.metadata, metadata=doc_status.metadata,

View File

@@ -4,12 +4,13 @@ This module contains all graph-related routes for the LightRAG API.
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends
from ..utils_api import get_api_key_dependency from ..utils_api import get_api_key_dependency
router = APIRouter(tags=["graph"]) router = APIRouter(tags=["graph"])
def create_graph_routes(rag, api_key: Optional[str] = None): def create_graph_routes(rag, api_key: Optional[str] = None):
optional_api_key = get_api_key_dependency(api_key) optional_api_key = get_api_key_dependency(api_key)

View File

@@ -4,7 +4,6 @@ This module contains all query-related routes for the LightRAG API.
import json import json
import logging import logging
import traceback
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
@@ -16,6 +15,7 @@ from ascii_colors import trace_exception
router = APIRouter(tags=["query"]) router = APIRouter(tags=["query"])
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str = Field( query: str = Field(
min_length=1, min_length=1,
@@ -131,15 +131,19 @@ class QueryRequest(BaseModel):
param.stream = is_stream param.stream = is_stream
return param return param
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
response: str = Field( response: str = Field(
description="The generated response", description="The generated response",
) )
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
optional_api_key = get_api_key_dependency(api_key) optional_api_key = get_api_key_dependency(api_key)
@router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) @router.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest): async def query_text(request: QueryRequest):
""" """
Handle a POST request at the /query endpoint to process user queries using RAG capabilities. Handle a POST request at the /query endpoint to process user queries using RAG capabilities.

View File

@@ -7,6 +7,7 @@ from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
def get_api_key_dependency(api_key: Optional[str]): def get_api_key_dependency(api_key: Optional[str]):
""" """
Create an API key dependency for route protection. Create an API key dependency for route protection.

View File

@@ -1,6 +1,6 @@
future
aiohttp aiohttp
configparser configparser
future
# Basic modules # Basic modules
numpy numpy