Fit linting
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -62,4 +62,4 @@ lightrag-dev/
|
||||
gui/
|
||||
|
||||
# unit-test files
|
||||
test_*
|
||||
test_*
|
||||
|
@@ -4,7 +4,6 @@ LightRAG FastAPI Server
|
||||
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
Depends,
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
@@ -14,7 +13,7 @@ import os
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import logging
|
||||
import argparse
|
||||
from typing import Optional, Dict
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
import configparser
|
||||
from ascii_colors import ASCIIColors
|
||||
@@ -73,6 +72,7 @@ scan_progress: Dict = {
|
||||
# Lock for thread-safe operations
|
||||
progress_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_default_host(binding_type: str) -> str:
|
||||
default_hosts = {
|
||||
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
||||
@@ -624,7 +624,9 @@ def create_app(args):
|
||||
scan_progress["indexed_count"] = 0
|
||||
scan_progress["progress"] = 0
|
||||
# Create background task
|
||||
task = asyncio.create_task(run_scanning_process(rag, doc_manager))
|
||||
task = asyncio.create_task(
|
||||
run_scanning_process(rag, doc_manager)
|
||||
)
|
||||
app.state.background_tasks.add(task)
|
||||
task.add_done_callback(app.state.background_tasks.discard)
|
||||
ASCIIColors.info(
|
||||
@@ -876,8 +878,12 @@ def create_app(args):
|
||||
# Webui mount webui/index.html
|
||||
static_dir = Path(__file__).parent / "webui"
|
||||
static_dir.mkdir(exist_ok=True)
|
||||
app.mount("/webui", StaticFiles(directory=static_dir, html=True, check_dir=True), name="webui")
|
||||
|
||||
app.mount(
|
||||
"/webui",
|
||||
StaticFiles(directory=static_dir, html=True, check_dir=True),
|
||||
name="webui",
|
||||
)
|
||||
|
||||
@app.get("/webui/")
|
||||
async def webui_root():
|
||||
return FileResponse(static_dir / "index.html")
|
||||
|
@@ -14,9 +14,7 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from ..utils_api import get_api_key_dependency
|
||||
@@ -39,6 +37,7 @@ progress_lock = asyncio.Lock()
|
||||
# Temporary file prefix
|
||||
temp_prefix = "__tmp__"
|
||||
|
||||
|
||||
class InsertTextRequest(BaseModel):
|
||||
text: str = Field(
|
||||
min_length=1,
|
||||
@@ -50,6 +49,7 @@ class InsertTextRequest(BaseModel):
|
||||
def strip_after(cls, text: str) -> str:
|
||||
return text.strip()
|
||||
|
||||
|
||||
class InsertTextsRequest(BaseModel):
|
||||
texts: list[str] = Field(
|
||||
min_length=1,
|
||||
@@ -61,10 +61,12 @@ class InsertTextsRequest(BaseModel):
|
||||
def strip_after(cls, texts: list[str]) -> list[str]:
|
||||
return [text.strip() for text in texts]
|
||||
|
||||
|
||||
class InsertResponse(BaseModel):
|
||||
status: str = Field(description="Status of the operation")
|
||||
message: str = Field(description="Message describing the operation result")
|
||||
|
||||
|
||||
class DocStatusResponse(BaseModel):
|
||||
@staticmethod
|
||||
def format_datetime(dt: Any) -> Optional[str]:
|
||||
@@ -84,9 +86,11 @@ class DocStatusResponse(BaseModel):
|
||||
error: Optional[str] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class DocsStatusesResponse(BaseModel):
|
||||
statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
|
||||
|
||||
|
||||
class DocumentManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -129,6 +133,7 @@ class DocumentManager:
|
||||
def is_supported_file(self, filename: str) -> bool:
|
||||
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
|
||||
|
||||
|
||||
async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
|
||||
try:
|
||||
content = ""
|
||||
@@ -145,7 +150,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
|
||||
case ".pdf":
|
||||
if not pm.is_installed("pypdf2"):
|
||||
pm.install("pypdf2")
|
||||
from PyPDF2 import PdfReader # type: ignore
|
||||
from PyPDF2 import PdfReader # type: ignore
|
||||
from io import BytesIO
|
||||
|
||||
pdf_file = BytesIO(file)
|
||||
@@ -184,10 +189,17 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
|
||||
for sheet in wb:
|
||||
content += f"Sheet: {sheet.title}\n"
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
|
||||
content += (
|
||||
"\t".join(
|
||||
str(cell) if cell is not None else "" for cell in row
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
content += "\n"
|
||||
case _:
|
||||
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
|
||||
logging.error(
|
||||
f"Unsupported file type: {file_path.name} (extension {ext})"
|
||||
)
|
||||
return False
|
||||
|
||||
# Insert into the RAG queue
|
||||
@@ -209,6 +221,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
|
||||
logging.error(f"Error deleting file {file_path}: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def pipeline_index_file(rag, file_path: Path):
|
||||
"""Index a file
|
||||
|
||||
@@ -231,7 +244,7 @@ async def pipeline_index_file(rag, file_path: Path):
|
||||
case ".pdf":
|
||||
if not pm.is_installed("pypdf2"):
|
||||
pm.install("pypdf2")
|
||||
from PyPDF2 import PdfReader # type: ignore
|
||||
from PyPDF2 import PdfReader # type: ignore
|
||||
from io import BytesIO
|
||||
|
||||
pdf_file = BytesIO(file)
|
||||
@@ -270,10 +283,17 @@ async def pipeline_index_file(rag, file_path: Path):
|
||||
for sheet in wb:
|
||||
content += f"Sheet: {sheet.title}\n"
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
content += "\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
|
||||
content += (
|
||||
"\t".join(
|
||||
str(cell) if cell is not None else "" for cell in row
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
content += "\n"
|
||||
case _:
|
||||
logging.error(f"Unsupported file type: {file_path.name} (extension {ext})")
|
||||
logging.error(
|
||||
f"Unsupported file type: {file_path.name} (extension {ext})"
|
||||
)
|
||||
return
|
||||
|
||||
# Insert into the RAG queue
|
||||
@@ -288,6 +308,7 @@ async def pipeline_index_file(rag, file_path: Path):
|
||||
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def pipeline_index_files(rag, file_paths: List[Path]):
|
||||
if not file_paths:
|
||||
return
|
||||
@@ -305,12 +326,14 @@ async def pipeline_index_files(rag, file_paths: List[Path]):
|
||||
logging.error(f"Error indexing files: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def pipeline_index_texts(rag, texts: List[str]):
|
||||
if not texts:
|
||||
return
|
||||
await rag.apipeline_enqueue_documents(texts)
|
||||
await rag.apipeline_process_enqueue_documents()
|
||||
|
||||
|
||||
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
|
||||
@@ -320,6 +343,7 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
return temp_path
|
||||
|
||||
|
||||
async def run_scanning_process(rag, doc_manager: DocumentManager):
|
||||
"""Background task to scan and index documents"""
|
||||
try:
|
||||
@@ -349,7 +373,10 @@ async def run_scanning_process(rag, doc_manager: DocumentManager):
|
||||
async with progress_lock:
|
||||
scan_progress["is_scanning"] = False
|
||||
|
||||
def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[str] = None):
|
||||
|
||||
def create_document_routes(
|
||||
rag, doc_manager: DocumentManager, api_key: Optional[str] = None
|
||||
):
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
@router.post("/scan", dependencies=[Depends(optional_api_key)])
|
||||
@@ -437,8 +464,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
|
||||
logging.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_text(request: InsertTextRequest, background_tasks: BackgroundTasks):
|
||||
@router.post(
|
||||
"/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
||||
)
|
||||
async def insert_text(
|
||||
request: InsertTextRequest, background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
Insert text into the RAG system.
|
||||
|
||||
@@ -466,8 +497,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
|
||||
logging.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/texts", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_texts(request: InsertTextsRequest, background_tasks: BackgroundTasks):
|
||||
@router.post(
|
||||
"/texts",
|
||||
response_model=InsertResponse,
|
||||
dependencies=[Depends(optional_api_key)],
|
||||
)
|
||||
async def insert_texts(
|
||||
request: InsertTextsRequest, background_tasks: BackgroundTasks
|
||||
):
|
||||
"""
|
||||
Insert multiple texts into the RAG system.
|
||||
|
||||
@@ -495,8 +532,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
|
||||
logging.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
|
||||
@router.post(
|
||||
"/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
||||
)
|
||||
async def insert_file(
|
||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||
):
|
||||
"""
|
||||
Insert a file directly into the RAG system.
|
||||
|
||||
@@ -532,8 +573,14 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
|
||||
logging.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/file_batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
async def insert_batch(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
|
||||
@router.post(
|
||||
"/file_batch",
|
||||
response_model=InsertResponse,
|
||||
dependencies=[Depends(optional_api_key)],
|
||||
)
|
||||
async def insert_batch(
|
||||
background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
|
||||
):
|
||||
"""
|
||||
Process multiple files in batch mode.
|
||||
|
||||
@@ -587,7 +634,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
|
||||
logging.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete("", response_model=InsertResponse, dependencies=[Depends(optional_api_key)])
|
||||
@router.delete(
|
||||
"", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]
|
||||
)
|
||||
async def clear_documents():
|
||||
"""
|
||||
Clear all documents from the RAG system.
|
||||
@@ -605,7 +654,9 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
|
||||
rag.text_chunks = []
|
||||
rag.entities_vdb = None
|
||||
rag.relationships_vdb = None
|
||||
return InsertResponse(status="success", message="All documents cleared successfully")
|
||||
return InsertResponse(
|
||||
status="success", message="All documents cleared successfully"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error DELETE /documents: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
@@ -651,8 +702,12 @@ def create_document_routes(rag, doc_manager: DocumentManager, api_key: Optional[
|
||||
content_summary=doc_status.content_summary,
|
||||
content_length=doc_status.content_length,
|
||||
status=doc_status.status,
|
||||
created_at=DocStatusResponse.format_datetime(doc_status.created_at),
|
||||
updated_at=DocStatusResponse.format_datetime(doc_status.updated_at),
|
||||
created_at=DocStatusResponse.format_datetime(
|
||||
doc_status.created_at
|
||||
),
|
||||
updated_at=DocStatusResponse.format_datetime(
|
||||
doc_status.updated_at
|
||||
),
|
||||
chunks_count=doc_status.chunks_count,
|
||||
error=doc_status.error,
|
||||
metadata=doc_status.metadata,
|
||||
|
@@ -4,12 +4,13 @@ This module contains all graph-related routes for the LightRAG API.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from ..utils_api import get_api_key_dependency
|
||||
|
||||
router = APIRouter(tags=["graph"])
|
||||
|
||||
|
||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
|
@@ -4,7 +4,6 @@ This module contains all query-related routes for the LightRAG API.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
@@ -16,6 +15,7 @@ from ascii_colors import trace_exception
|
||||
|
||||
router = APIRouter(tags=["query"])
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str = Field(
|
||||
min_length=1,
|
||||
@@ -131,15 +131,19 @@ class QueryRequest(BaseModel):
|
||||
param.stream = is_stream
|
||||
return param
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
response: str = Field(
|
||||
description="The generated response",
|
||||
)
|
||||
|
||||
|
||||
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||
optional_api_key = get_api_key_dependency(api_key)
|
||||
|
||||
@router.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)])
|
||||
@router.post(
|
||||
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
|
||||
)
|
||||
async def query_text(request: QueryRequest):
|
||||
"""
|
||||
Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
|
||||
|
@@ -7,6 +7,7 @@ from fastapi import HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def get_api_key_dependency(api_key: Optional[str]):
|
||||
"""
|
||||
Create an API key dependency for route protection.
|
||||
|
@@ -1,6 +1,6 @@
|
||||
future
|
||||
aiohttp
|
||||
configparser
|
||||
future
|
||||
|
||||
# Basic modules
|
||||
numpy
|
||||
|
Reference in New Issue
Block a user