Fix refactoring error on document handling
- Fix refactoring error on pipeline_index_file
- Delete unused func: scan_directory
- Add type hints for rag for better maintainability
- Refine comments for better understanding
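In practice the fix turns pipeline_index_file into a thin wrapper around pipeline_enqueue_file, as the large hunk below shows. A condensed sketch of the resulting flow (names taken from this diff, error handling trimmed):

    # Sketch of the corrected flow, condensed from the diff below: parsing and
    # enqueueing live in pipeline_enqueue_file; indexing only triggers queue
    # processing when the enqueue succeeded.
    async def pipeline_index_file(rag: LightRAG, file_path: Path):
        try:
            if await pipeline_enqueue_file(rag, file_path):
                await rag.apipeline_process_enqueue_documents()
        except Exception as e:
            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
            logging.error(traceback.format_exc())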
@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Any
 from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
 from pydantic import BaseModel, Field, field_validator

+from lightrag import LightRAG
 from lightrag.base import DocProcessingStatus, DocStatus
 from ..utils_api import get_api_key_dependency

@@ -76,6 +77,20 @@ class DocStatusResponse(BaseModel):
             return dt
         return dt.isoformat()

+    """Response model for document status
+
+    Attributes:
+        id: Document identifier
+        content_summary: Summary of document content
+        content_length: Length of document content
+        status: Current processing status
+        created_at: Creation timestamp (ISO format string)
+        updated_at: Last update timestamp (ISO format string)
+        chunks_count: Number of chunks (optional)
+        error: Error message if any (optional)
+        metadata: Additional metadata (optional)
+    """
+
     id: str
     content_summary: str
     content_length: int
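Only the first three fields fall inside this hunk. A hedged sketch of how the full model plausibly lines up with the new Attributes docstring; the types and defaults of the optional fields are assumptions, not shown in the diff:

    # Hedged sketch: field list follows the docstring above; Optional types and
    # None defaults for chunks_count/error/metadata are assumed, not shown here.
    class DocStatusResponse(BaseModel):
        id: str
        content_summary: str
        content_length: int
        status: DocStatus                   # DocStatus imported from lightrag.base (see first hunk)
        created_at: str                     # ISO format string
        updated_at: str                     # ISO format string
        chunks_count: Optional[int] = None  # assumed
        error: Optional[str] = None         # assumed
        metadata: Optional[dict] = None     # assumed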
@@ -112,6 +127,7 @@ class DocumentManager:
         self.input_dir.mkdir(parents=True, exist_ok=True)

     def scan_directory_for_new_files(self) -> List[Path]:
+        """Scan input directory for new files"""
         new_files = []
         for ext in self.supported_extensions:
             logging.info(f"Scanning for {ext} files in {self.input_dir}")
@@ -120,12 +136,12 @@ class DocumentManager:
                     new_files.append(file_path)
         return new_files

-    def scan_directory(self) -> List[Path]:
-        new_files = []
-        for ext in self.supported_extensions:
-            for file_path in self.input_dir.rglob(f"*{ext}"):
-                new_files.append(file_path)
-        return new_files
+    # def scan_directory(self) -> List[Path]:
+    #     new_files = []
+    #     for ext in self.supported_extensions:
+    #         for file_path in self.input_dir.rglob(f"*{ext}"):
+    #             new_files.append(file_path)
+    #     return new_files

     def mark_as_indexed(self, file_path: Path):
         self.indexed_files.add(file_path)
@@ -134,7 +150,16 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)


-async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
+async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
+    """Add a file to the queue for processing
+
+    Args:
+        rag: LightRAG instance
+        file_path: Path to the saved file
+    Returns:
+        bool: True if the file was successfully enqueued, False otherwise
+    """
+
     try:
         content = ""
         ext = file_path.suffix.lower()
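A hypothetical call site for the newly documented helper, assuming an already-initialized LightRAG instance; it mirrors what pipeline_index_file does further down in this diff:

    # Hypothetical usage sketch (rag is an existing LightRAG instance; the path is illustrative).
    ok = await pipeline_enqueue_file(rag, Path("./inputs/report.pdf"))
    if ok:
        # drain the shared pipeline once the file has been queued
        await rag.apipeline_process_enqueue_documents()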
@@ -165,7 +190,9 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:

                 docx_file = BytesIO(file)
                 doc = Document(docx_file)
-                content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
+                content = "\n".join(
+                    [paragraph.text for paragraph in doc.paragraphs]
+                )
             case ".pptx":
                 if not pm.is_installed("pptx"):
                     pm.install("pptx")
@@ -205,13 +232,19 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
         # Insert into the RAG queue
         if content:
             await rag.apipeline_enqueue_documents(content)
-            logging.info(f"Successfully fetched and enqueued file: {file_path.name}")
+            logging.info(
+                f"Successfully fetched and enqueued file: {file_path.name}"
+            )
             return True
         else:
-            logging.error(f"No content could be extracted from file: {file_path.name}")
+            logging.error(
+                f"No content could be extracted from file: {file_path.name}"
+            )

     except Exception as e:
-        logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
+        logging.error(
+            f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+        )
         logging.error(traceback.format_exc())
     finally:
         if file_path.name.startswith(temp_prefix):
@@ -222,7 +255,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
     return False


-async def pipeline_index_file(rag, file_path: Path):
+async def pipeline_index_file(rag: LightRAG, file_path: Path):
     """Index a file

     Args:
@@ -230,90 +263,26 @@ async def pipeline_index_file(rag, file_path: Path):
         file_path: Path to the saved file
     """
     try:
-        content = ""
-        ext = file_path.suffix.lower()
-
-        file = None
-        async with aiofiles.open(file_path, "rb") as f:
-            file = await f.read()
-
-        # Process based on file type
-        match ext:
-            case ".txt" | ".md":
-                content = file.decode("utf-8")
-            case ".pdf":
-                if not pm.is_installed("pypdf2"):
-                    pm.install("pypdf2")
-                from PyPDF2 import PdfReader  # type: ignore
-                from io import BytesIO
-
-                pdf_file = BytesIO(file)
-                reader = PdfReader(pdf_file)
-                for page in reader.pages:
-                    content += page.extract_text() + "\n"
-            case ".docx":
-                if not pm.is_installed("docx"):
-                    pm.install("docx")
-                from docx import Document
-                from io import BytesIO
-
-                docx_file = BytesIO(file)
-                doc = Document(docx_file)
-                content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
-            case ".pptx":
-                if not pm.is_installed("pptx"):
-                    pm.install("pptx")
-                from pptx import Presentation
-                from io import BytesIO
-
-                pptx_file = BytesIO(file)
-                prs = Presentation(pptx_file)
-                for slide in prs.slides:
-                    for shape in slide.shapes:
-                        if hasattr(shape, "text"):
-                            content += shape.text + "\n"
-            case ".xlsx":
-                if not pm.is_installed("openpyxl"):
-                    pm.install("openpyxl")
-                from openpyxl import load_workbook
-                from io import BytesIO
-
-                xlsx_file = BytesIO(file)
-                wb = load_workbook(xlsx_file)
-                for sheet in wb:
-                    content += f"Sheet: {sheet.title}\n"
-                    for row in sheet.iter_rows(values_only=True):
-                        content += (
-                            "\t".join(
-                                str(cell) if cell is not None else "" for cell in row
-                            )
-                            + "\n"
-                        )
-                    content += "\n"
-            case _:
-                logging.error(
-                    f"Unsupported file type: {file_path.name} (extension {ext})"
-                )
-                return
-
-        # Insert into the RAG queue
-        if content:
-            await rag.apipeline_enqueue_documents(content)
+        if await pipeline_enqueue_file(rag, file_path):
             await rag.apipeline_process_enqueue_documents()
-            logging.info(f"Successfully indexed file: {file_path.name}")
-        else:
-            logging.error(f"No content could be extracted from file: {file_path.name}")

     except Exception as e:
         logging.error(f"Error indexing file {file_path.name}: {str(e)}")
         logging.error(traceback.format_exc())


-async def pipeline_index_files(rag, file_paths: List[Path]):
+async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
     """Index multiple files concurrently

     Args:
         rag: LightRAG instance
         file_paths: Paths to the files to index
     """
     if not file_paths:
         return
     try:
         enqueued = False

         if len(file_paths) == 1:
             enqueued = await pipeline_enqueue_file(rag, file_paths[0])
         else:
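The else branch for the multi-file case lies outside this hunk. Purely as an assumption about its shape, a concurrent enqueue followed by a single processing pass might look roughly like this:

    # Assumed sketch only -- the actual else branch is not shown in this hunk.
    # Assumes asyncio is imported at module level.
    results = await asyncio.gather(
        *(pipeline_enqueue_file(rag, path) for path in file_paths)
    )
    enqueued = any(results)
    if enqueued:
        await rag.apipeline_process_enqueue_documents()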
@@ -327,7 +296,13 @@ async def pipeline_index_files(rag, file_paths: List[Path]):
         logging.error(traceback.format_exc())


-async def pipeline_index_texts(rag, texts: List[str]):
+async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
+    """Index a list of texts
+
+    Args:
+        rag: LightRAG instance
+        texts: The texts to index
+    """
     if not texts:
         return
     await rag.apipeline_enqueue_documents(texts)
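A hypothetical call, for example from the text-insert route, assuming the same rag instance as above:

    # Hypothetical usage sketch: queue two raw strings on the shared pipeline.
    await pipeline_index_texts(rag, ["First document body.", "Second document body."])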
@@ -335,16 +310,29 @@ async def pipeline_index_texts(rag, texts: List[str]):


 async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
+    """Save the uploaded file to a temporary location
+
+    Args:
+        file: The uploaded file
+
+    Returns:
+        Path: The path to the saved file
+    """
+    # Generate unique filename to avoid conflicts
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
+
+    # Create a temporary file to save the uploaded content
     temp_path = input_dir / "temp" / unique_filename
     temp_path.parent.mkdir(exist_ok=True)
+
+    # Save the file
     with open(temp_path, "wb") as buffer:
         shutil.copyfileobj(file.file, buffer)
     return temp_path


-async def run_scanning_process(rag, doc_manager: DocumentManager):
+async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     """Background task to scan and index documents"""
     try:
         new_files = doc_manager.scan_directory_for_new_files()
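The temp_prefix naming is what lets the finally block in pipeline_enqueue_file (earlier hunk) recognize and clean up these files. An illustration with an assumed prefix value:

    # Illustration only: temp_prefix is defined elsewhere in the module; its
    # exact value is assumed here.
    temp_prefix = "__tmp_"
    # An upload of "report.pdf" saved at 2025-02-18 09:30:00 would land at
    #   <input_dir>/temp/__tmp_20250218_093000_report.pdf
    # and file_path.name.startswith(temp_prefix) marks it for cleanup after enqueueing.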
@@ -375,7 +363,7 @@ async def run_scanning_process(rag, doc_manager: DocumentManager):


 def create_document_routes(
-    rag, doc_manager: DocumentManager, api_key: Optional[str] = None
+    rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
 ):
     optional_api_key = get_api_key_dependency(api_key)

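With the rag parameter now typed, callers construct the router roughly as follows. This is a hedged wiring sketch that assumes create_document_routes returns the APIRouter it populates; the LightRAG and DocumentManager constructor arguments are assumptions, not taken from this diff:

    # Hedged wiring sketch -- constructor arguments and mount behaviour are assumptions.
    from fastapi import FastAPI

    app = FastAPI()
    rag = LightRAG(working_dir="./rag_storage")          # assumed constructor argument
    doc_manager = DocumentManager(input_dir="./inputs")  # assumed constructor argument
    app.include_router(create_document_routes(rag, doc_manager, api_key=None))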
@@ -388,9 +376,6 @@ def create_document_routes(
         and processes them. If a scanning process is already running, it returns a status indicating
         that fact.

-        Args:
-            background_tasks (BackgroundTasks): FastAPI background tasks handler
-
         Returns:
             dict: A dictionary containing the scanning status
         """
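Assuming the router is mounted under a /documents prefix (an assumption, not shown in this diff), the endpoint reduces to the behaviour visible in the next hunk:

    # Hypothetical request/response pair; path prefix is assumed.
    #   POST /documents/scan
    #   -> {"status": "scanning_started"}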
@@ -402,6 +387,7 @@ def create_document_routes(
             scan_progress["indexed_count"] = 0
             scan_progress["progress"] = 0

+        # Start the scanning process in the background
         background_tasks.add_task(run_scanning_process, rag, doc_manager)
         return {"status": "scanning_started"}

@@ -453,6 +439,7 @@ def create_document_routes(
             with open(file_path, "wb") as buffer:
                 shutil.copyfileobj(file.file, buffer)

+            # Add to background tasks
             background_tasks.add_task(pipeline_index_file, rag, file_path)

             return InsertResponse(
@@ -562,6 +549,8 @@ def create_document_routes(
                 )

             temp_path = await save_temp_file(doc_manager.input_dir, file)
+
+            # Add to background tasks
             background_tasks.add_task(pipeline_index_file, rag, temp_path)

             return InsertResponse(
@@ -606,6 +595,7 @@ def create_document_routes(

             for file in files:
                 if doc_manager.is_supported_file(file.filename):
+                    # Create a temporary file to save the uploaded content
                     temp_files.append(await save_temp_file(doc_manager.input_dir, file))
                     inserted_count += 1
                 else:
@@ -614,6 +604,7 @@ def create_document_routes(
             if temp_files:
                 background_tasks.add_task(pipeline_index_files, rag, temp_files)

+            # Prepare status message
             if inserted_count == len(files):
                 status = "success"
                 status_message = f"Successfully inserted all {inserted_count} documents"