Fix refactoring error in document handling

- Fix refactoring error in pipeline_index_file
- Delete unused func: scan_directory
- Add type hints for rag for better maintainability
- Refine comments for better understanding
yangdx
2025-02-20 14:30:41 +08:00
parent 62e1fe5df2
commit 82a4cb3e79


@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Any
 from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
 from pydantic import BaseModel, Field, field_validator
+from lightrag import LightRAG
 from lightrag.base import DocProcessingStatus, DocStatus
 from ..utils_api import get_api_key_dependency
@@ -76,6 +77,20 @@ class DocStatusResponse(BaseModel):
             return dt
         return dt.isoformat()

+    """Response model for document status
+
+    Attributes:
+        id: Document identifier
+        content_summary: Summary of document content
+        content_length: Length of document content
+        status: Current processing status
+        created_at: Creation timestamp (ISO format string)
+        updated_at: Last update timestamp (ISO format string)
+        chunks_count: Number of chunks (optional)
+        error: Error message if any (optional)
+        metadata: Additional metadata (optional)
+    """
+
     id: str
     content_summary: str
     content_length: int
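The new class docstring spells out the response contract: timestamps arrive as ISO-format strings, produced by the format_datetime helper shown in context above. A minimal, self-contained sketch of that pattern (hypothetical class name and trimmed field set; not the project's actual model):

    from datetime import datetime
    from typing import Any, Optional
    from pydantic import BaseModel

    class StatusResponseSketch(BaseModel):
        id: str
        content_summary: str
        content_length: int
        created_at: Optional[str] = None  # ISO-format string, per the docstring above

        @staticmethod
        def format_datetime(dt: Any) -> Optional[str]:
            # Pass strings through, serialize datetimes, keep None as None
            if dt is None:
                return None
            if isinstance(dt, str):
                return dt
            return dt.isoformat()

    resp = StatusResponseSketch(
        id="doc-123",
        content_summary="First lines of the document...",
        content_length=2048,
        created_at=StatusResponseSketch.format_datetime(datetime(2025, 2, 20, 14, 30, 41)),
    )
    print(resp.model_dump())  # created_at is "2025-02-20T14:30:41"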
@@ -112,6 +127,7 @@ class DocumentManager:
         self.input_dir.mkdir(parents=True, exist_ok=True)

     def scan_directory_for_new_files(self) -> List[Path]:
+        """Scan input directory for new files"""
         new_files = []
         for ext in self.supported_extensions:
             logging.info(f"Scanning for {ext} files in {self.input_dir}")
@@ -120,12 +136,12 @@ class DocumentManager:
                 new_files.append(file_path)
         return new_files

-    def scan_directory(self) -> List[Path]:
-        new_files = []
-        for ext in self.supported_extensions:
-            for file_path in self.input_dir.rglob(f"*{ext}"):
-                new_files.append(file_path)
-        return new_files
+    # def scan_directory(self) -> List[Path]:
+    #     new_files = []
+    #     for ext in self.supported_extensions:
+    #         for file_path in self.input_dir.rglob(f"*{ext}"):
+    #             new_files.append(file_path)
+    #     return new_files

     def mark_as_indexed(self, file_path: Path):
         self.indexed_files.add(file_path)
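With scan_directory retired, scan_directory_for_new_files is the single scanning entry point, paired with mark_as_indexed to avoid re-processing. A hedged usage sketch (the helper function is illustrative, not part of the module):

    from pathlib import Path
    from typing import List

    def collect_new_files(doc_manager) -> List[Path]:
        # Scan the input directory, then record each hit so it is skipped next time
        new_files = doc_manager.scan_directory_for_new_files()
        for file_path in new_files:
            doc_manager.mark_as_indexed(file_path)
        return new_files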
@@ -134,7 +150,16 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)


-async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
+async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
+    """Add a file to the queue for processing
+
+    Args:
+        rag: LightRAG instance
+        file_path: Path to the saved file
+
+    Returns:
+        bool: True if the file was successfully enqueued, False otherwise
+    """
     try:
         content = ""
         ext = file_path.suffix.lower()
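The enqueue helper now carries an explicit LightRAG type hint and a documented boolean result. A hedged usage sketch, assuming an initialized LightRAG instance and pipeline_enqueue_file imported from this module:

    import asyncio
    from pathlib import Path

    async def enqueue_one(rag, path: str) -> bool:
        # True only if content was extracted and queued; failures are logged inside
        return await pipeline_enqueue_file(rag, Path(path))

    # ok = asyncio.run(enqueue_one(rag, "./inputs/report.pdf"))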
@@ -165,7 +190,9 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
docx_file = BytesIO(file) docx_file = BytesIO(file)
doc = Document(docx_file) doc = Document(docx_file)
content = "\n".join([paragraph.text for paragraph in doc.paragraphs]) content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
)
case ".pptx": case ".pptx":
if not pm.is_installed("pptx"): if not pm.is_installed("pptx"):
pm.install("pptx") pm.install("pptx")
@@ -205,13 +232,19 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:

         # Insert into the RAG queue
         if content:
             await rag.apipeline_enqueue_documents(content)
-            logging.info(f"Successfully fetched and enqueued file: {file_path.name}")
+            logging.info(
+                f"Successfully fetched and enqueued file: {file_path.name}"
+            )
             return True
         else:
-            logging.error(f"No content could be extracted from file: {file_path.name}")
+            logging.error(
+                f"No content could be extracted from file: {file_path.name}"
+            )

     except Exception as e:
-        logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
+        logging.error(
+            f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+        )
         logging.error(traceback.format_exc())
     finally:
         if file_path.name.startswith(temp_prefix):
@@ -222,7 +255,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
     return False


-async def pipeline_index_file(rag, file_path: Path):
+async def pipeline_index_file(rag: LightRAG, file_path: Path):
     """Index a file

     Args:
@@ -230,90 +263,26 @@ async def pipeline_index_file(rag, file_path: Path):
         file_path: Path to the saved file
     """
     try:
-        content = ""
-        ext = file_path.suffix.lower()
-        file = None
-        async with aiofiles.open(file_path, "rb") as f:
-            file = await f.read()
-
-        # Process based on file type
-        match ext:
-            case ".txt" | ".md":
-                content = file.decode("utf-8")
-            case ".pdf":
-                if not pm.is_installed("pypdf2"):
-                    pm.install("pypdf2")
-                from PyPDF2 import PdfReader  # type: ignore
-                from io import BytesIO
-                pdf_file = BytesIO(file)
-                reader = PdfReader(pdf_file)
-                for page in reader.pages:
-                    content += page.extract_text() + "\n"
-            case ".docx":
-                if not pm.is_installed("docx"):
-                    pm.install("docx")
-                from docx import Document
-                from io import BytesIO
-                docx_file = BytesIO(file)
-                doc = Document(docx_file)
-                content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
-            case ".pptx":
-                if not pm.is_installed("pptx"):
-                    pm.install("pptx")
-                from pptx import Presentation
-                from io import BytesIO
-                pptx_file = BytesIO(file)
-                prs = Presentation(pptx_file)
-                for slide in prs.slides:
-                    for shape in slide.shapes:
-                        if hasattr(shape, "text"):
-                            content += shape.text + "\n"
-            case ".xlsx":
-                if not pm.is_installed("openpyxl"):
-                    pm.install("openpyxl")
-                from openpyxl import load_workbook
-                from io import BytesIO
-                xlsx_file = BytesIO(file)
-                wb = load_workbook(xlsx_file)
-                for sheet in wb:
-                    content += f"Sheet: {sheet.title}\n"
-                    for row in sheet.iter_rows(values_only=True):
-                        content += (
-                            "\t".join(
-                                str(cell) if cell is not None else "" for cell in row
-                            )
-                            + "\n"
-                        )
-                    content += "\n"
-            case _:
-                logging.error(
-                    f"Unsupported file type: {file_path.name} (extension {ext})"
-                )
-                return
-
-        # Insert into the RAG queue
-        if content:
-            await rag.apipeline_enqueue_documents(content)
+        if await pipeline_enqueue_file(rag, file_path):
             await rag.apipeline_process_enqueue_documents()
-            logging.info(f"Successfully indexed file: {file_path.name}")
-        else:
-            logging.error(f"No content could be extracted from file: {file_path.name}")
     except Exception as e:
         logging.error(f"Error indexing file {file_path.name}: {str(e)}")
         logging.error(traceback.format_exc())


-async def pipeline_index_files(rag, file_paths: List[Path]):
+async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
+    """Index multiple files concurrently
+
+    Args:
+        rag: LightRAG instance
+        file_paths: Paths to the files to index
+    """
     if not file_paths:
         return
     try:
         enqueued = False
         if len(file_paths) == 1:
             enqueued = await pipeline_enqueue_file(rag, file_paths[0])
         else:
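After the fix, pipeline_index_file only delegates to pipeline_enqueue_file(rag, file_path) and, on success, drains the queue; pipeline_index_files handles batches. A hedged driver sketch (rag is assumed to be an initialized LightRAG instance, and both functions are assumed imported from this module):

    import asyncio
    from pathlib import Path

    async def index_paths(rag, paths):
        # One file: enqueue and process it; several: let the batch helper fan out
        if len(paths) == 1:
            await pipeline_index_file(rag, Path(paths[0]))
        else:
            await pipeline_index_files(rag, [Path(p) for p in paths])

    # asyncio.run(index_paths(rag, ["./inputs/a.md", "./inputs/b.pdf"]))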
@@ -327,7 +296,13 @@ async def pipeline_index_files(rag, file_paths: List[Path]):
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
async def pipeline_index_texts(rag, texts: List[str]): async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
"""Index a list of texts
Args:
rag: LightRAG instance
texts: The texts to index
"""
if not texts: if not texts:
return return
await rag.apipeline_enqueue_documents(texts) await rag.apipeline_enqueue_documents(texts)
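A hedged sketch of the text variant, which takes raw strings instead of files (same assumptions as above):

    import asyncio

    async def index_snippets(rag):
        await pipeline_index_texts(rag, ["First note to index.", "Second note to index."])

    # asyncio.run(index_snippets(rag))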
@@ -335,16 +310,29 @@ async def pipeline_index_texts(rag, texts: List[str]):


 async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
+    """Save the uploaded file to a temporary location
+
+    Args:
+        file: The uploaded file
+
+    Returns:
+        Path: The path to the saved file
+    """
+    # Generate unique filename to avoid conflicts
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
+
+    # Create a temporary file to save the uploaded content
     temp_path = input_dir / "temp" / unique_filename
     temp_path.parent.mkdir(exist_ok=True)
+
+    # Save the file
     with open(temp_path, "wb") as buffer:
         shutil.copyfileobj(file.file, buffer)
     return temp_path


-async def run_scanning_process(rag, doc_manager: DocumentManager):
+async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     """Background task to scan and index documents"""
     try:
         new_files = doc_manager.scan_directory_for_new_files()
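The naming scheme in save_temp_file can be traced with a short sketch; the temp_prefix value below is an assumption for illustration, not taken from this diff:

    from datetime import datetime
    from pathlib import Path

    temp_prefix = "__tmp_"  # hypothetical value; the module defines the real constant
    input_dir = Path("./inputs")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_filename = f"{temp_prefix}{timestamp}_upload.pdf"
    print(input_dir / "temp" / unique_filename)  # e.g. inputs/temp/__tmp_20250220_143041_upload.pdf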
@@ -375,7 +363,7 @@ async def run_scanning_process(rag, doc_manager: DocumentManager):

 def create_document_routes(
-    rag, doc_manager: DocumentManager, api_key: Optional[str] = None
+    rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
 ):
     optional_api_key = get_api_key_dependency(api_key)
@@ -388,9 +376,6 @@ def create_document_routes(
         and processes them. If a scanning process is already running, it returns a status indicating
         that fact.

-        Args:
-            background_tasks (BackgroundTasks): FastAPI background tasks handler
-
         Returns:
             dict: A dictionary containing the scanning status
         """
@@ -402,6 +387,7 @@ def create_document_routes(
             scan_progress["indexed_count"] = 0
             scan_progress["progress"] = 0

+        # Start the scanning process in the background
         background_tasks.add_task(run_scanning_process, rag, doc_manager)

         return {"status": "scanning_started"}
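A hedged client-side sketch of triggering this scan route; the URL path, port, and header name are assumptions, since the route decorator and server configuration sit outside this hunk:

    import requests

    resp = requests.post(
        "http://localhost:9621/documents/scan",  # assumed path and port
        headers={"X-API-Key": "your-api-key"},   # only needed when an api_key is configured
        timeout=10,
    )
    print(resp.json())  # expected: {"status": "scanning_started"}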
@@ -453,6 +439,7 @@ def create_document_routes(
             with open(file_path, "wb") as buffer:
                 shutil.copyfileobj(file.file, buffer)

+            # Add to background tasks
             background_tasks.add_task(pipeline_index_file, rag, file_path)

             return InsertResponse(
@@ -562,6 +549,8 @@ def create_document_routes(
             )

             temp_path = await save_temp_file(doc_manager.input_dir, file)
+
+            # Add to background tasks
             background_tasks.add_task(pipeline_index_file, rag, temp_path)

             return InsertResponse(
@@ -606,6 +595,7 @@ def create_document_routes(
             for file in files:
                 if doc_manager.is_supported_file(file.filename):
+                    # Create a temporary file to save the uploaded content
                     temp_files.append(await save_temp_file(doc_manager.input_dir, file))
                     inserted_count += 1
                 else:
@@ -614,6 +604,7 @@ def create_document_routes(
             if temp_files:
                 background_tasks.add_task(pipeline_index_files, rag, temp_files)

+            # Prepare status message
             if inserted_count == len(files):
                 status = "success"
                 status_message = f"Successfully inserted all {inserted_count} documents"