Fix refactoring error in document handling

- Fix refactoring error in pipeline_index_file
- Delete unused func: scan_directory
- Add type hints for rag for better maintainability
- Refine comments for better understanding
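
The heart of the fix: pipeline_index_file previously re-implemented the whole per-format extraction block and enqueued the content itself; after this change it delegates to pipeline_enqueue_file and only drains the processing queue when enqueueing succeeds. A minimal sketch of the new flow (the import path is an assumption; this page does not show the module's location):

    import logging
    import traceback
    from pathlib import Path

    from lightrag import LightRAG

    # Import path assumed from the relative imports in this diff; adjust to your layout.
    from lightrag.api.routers.document_routes import pipeline_enqueue_file


    async def index_one(rag: LightRAG, file_path: Path) -> None:
        # Extraction and enqueueing happen inside pipeline_enqueue_file; it returns
        # True only when some content was extracted and queued.
        try:
            if await pipeline_enqueue_file(rag, file_path):
                await rag.apipeline_process_enqueue_documents()
        except Exception as e:
            logging.error(f"Error indexing file {file_path.name}: {str(e)}")
            logging.error(traceback.format_exc())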
@@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Any
 from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
 from pydantic import BaseModel, Field, field_validator
 
+from lightrag import LightRAG
 from lightrag.base import DocProcessingStatus, DocStatus
 from ..utils_api import get_api_key_dependency
 
@@ -76,6 +77,20 @@ class DocStatusResponse(BaseModel):
             return dt
         return dt.isoformat()
 
+    """Response model for document status
+
+    Attributes:
+        id: Document identifier
+        content_summary: Summary of document content
+        content_length: Length of document content
+        status: Current processing status
+        created_at: Creation timestamp (ISO format string)
+        updated_at: Last update timestamp (ISO format string)
+        chunks_count: Number of chunks (optional)
+        error: Error message if any (optional)
+        metadata: Additional metadata (optional)
+    """
+
     id: str
     content_summary: str
     content_length: int
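
For orientation, a hypothetical instantiation matching the attribute list above (all field values are invented for illustration; the status value is whatever DocStatus serializes to, which this hunk does not show):

    doc = DocStatusResponse(
        id="doc-0001",                       # invented example values throughout
        content_summary="Quarterly report, first pages",
        content_length=48213,
        status="processed",                  # real values come from DocStatus, not shown in this hunk
        created_at="2025-02-18T09:30:00",
        updated_at="2025-02-18T09:31:10",
        chunks_count=37,
        error=None,
        metadata={"source": "upload"},
    )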
@@ -112,6 +127,7 @@ class DocumentManager:
         self.input_dir.mkdir(parents=True, exist_ok=True)
 
     def scan_directory_for_new_files(self) -> List[Path]:
+        """Scan input directory for new files"""
         new_files = []
         for ext in self.supported_extensions:
             logging.info(f"Scanning for {ext} files in {self.input_dir}")
@@ -120,12 +136,12 @@ class DocumentManager:
                     new_files.append(file_path)
         return new_files
 
-    def scan_directory(self) -> List[Path]:
-        new_files = []
-        for ext in self.supported_extensions:
-            for file_path in self.input_dir.rglob(f"*{ext}"):
-                new_files.append(file_path)
-        return new_files
+    # def scan_directory(self) -> List[Path]:
+    #     new_files = []
+    #     for ext in self.supported_extensions:
+    #         for file_path in self.input_dir.rglob(f"*{ext}"):
+    #             new_files.append(file_path)
+    #     return new_files
 
     def mark_as_indexed(self, file_path: Path):
         self.indexed_files.add(file_path)
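
The kept scanner walks the input directory once per supported extension with rglob and collects paths that have not been indexed yet; a standalone sketch of that pattern (the membership check is assumed, since it sits between the hunks shown here):

    from pathlib import Path


    def find_new_files(input_dir: Path, supported_extensions: tuple, indexed_files: set) -> list:
        """Standalone sketch of the rglob-per-extension scan kept above."""
        new_files = []
        for ext in supported_extensions:
            for file_path in input_dir.rglob(f"*{ext}"):
                if file_path not in indexed_files:  # skip files already indexed (assumed check)
                    new_files.append(file_path)
        return new_files


    # Example: find_new_files(Path("./inputs"), (".txt", ".md", ".pdf"), set())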
@@ -134,7 +150,16 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
 
 
-async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
+async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
+    """Add a file to the queue for processing
+
+    Args:
+        rag: LightRAG instance
+        file_path: Path to the saved file
+    Returns:
+        bool: True if the file was successfully enqueued, False otherwise
+    """
+
     try:
         content = ""
         ext = file_path.suffix.lower()
@@ -165,7 +190,9 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
 
                 docx_file = BytesIO(file)
                 doc = Document(docx_file)
-                content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
+                content = "\n".join(
+                    [paragraph.text for paragraph in doc.paragraphs]
+                )
             case ".pptx":
                 if not pm.is_installed("pptx"):
                     pm.install("pptx")
@@ -205,13 +232,19 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
         # Insert into the RAG queue
         if content:
             await rag.apipeline_enqueue_documents(content)
-            logging.info(f"Successfully fetched and enqueued file: {file_path.name}")
+            logging.info(
+                f"Successfully fetched and enqueued file: {file_path.name}"
+            )
             return True
         else:
-            logging.error(f"No content could be extracted from file: {file_path.name}")
+            logging.error(
+                f"No content could be extracted from file: {file_path.name}"
+            )
 
     except Exception as e:
-        logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
+        logging.error(
+            f"Error processing or enqueueing file {file_path.name}: {str(e)}"
+        )
         logging.error(traceback.format_exc())
     finally:
         if file_path.name.startswith(temp_prefix):
@@ -222,7 +255,7 @@ async def pipeline_enqueue_file(rag, file_path: Path) -> bool:
     return False
 
 
-async def pipeline_index_file(rag, file_path: Path):
+async def pipeline_index_file(rag: LightRAG, file_path: Path):
     """Index a file
 
     Args:
@@ -230,90 +263,26 @@ async def pipeline_index_file(rag, file_path: Path):
         file_path: Path to the saved file
     """
     try:
-        content = ""
-        ext = file_path.suffix.lower()
-
-        file = None
-        async with aiofiles.open(file_path, "rb") as f:
-            file = await f.read()
-
-        # Process based on file type
-        match ext:
-            case ".txt" | ".md":
-                content = file.decode("utf-8")
-            case ".pdf":
-                if not pm.is_installed("pypdf2"):
-                    pm.install("pypdf2")
-                from PyPDF2 import PdfReader  # type: ignore
-                from io import BytesIO
-
-                pdf_file = BytesIO(file)
-                reader = PdfReader(pdf_file)
-                for page in reader.pages:
-                    content += page.extract_text() + "\n"
-            case ".docx":
-                if not pm.is_installed("docx"):
-                    pm.install("docx")
-                from docx import Document
-                from io import BytesIO
-
-                docx_file = BytesIO(file)
-                doc = Document(docx_file)
-                content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
-            case ".pptx":
-                if not pm.is_installed("pptx"):
-                    pm.install("pptx")
-                from pptx import Presentation
-                from io import BytesIO
-
-                pptx_file = BytesIO(file)
-                prs = Presentation(pptx_file)
-                for slide in prs.slides:
-                    for shape in slide.shapes:
-                        if hasattr(shape, "text"):
-                            content += shape.text + "\n"
-            case ".xlsx":
-                if not pm.is_installed("openpyxl"):
-                    pm.install("openpyxl")
-                from openpyxl import load_workbook
-                from io import BytesIO
-
-                xlsx_file = BytesIO(file)
-                wb = load_workbook(xlsx_file)
-                for sheet in wb:
-                    content += f"Sheet: {sheet.title}\n"
-                    for row in sheet.iter_rows(values_only=True):
-                        content += (
-                            "\t".join(
-                                str(cell) if cell is not None else "" for cell in row
-                            )
-                            + "\n"
-                        )
-                    content += "\n"
-            case _:
-                logging.error(
-                    f"Unsupported file type: {file_path.name} (extension {ext})"
-                )
-                return
-
-        # Insert into the RAG queue
-        if content:
-            await rag.apipeline_enqueue_documents(content)
+        if await pipeline_enqueue_file(rag, file_path):
             await rag.apipeline_process_enqueue_documents()
-            logging.info(f"Successfully indexed file: {file_path.name}")
-        else:
-            logging.error(f"No content could be extracted from file: {file_path.name}")
 
     except Exception as e:
         logging.error(f"Error indexing file {file_path.name}: {str(e)}")
         logging.error(traceback.format_exc())
 
 
-async def pipeline_index_files(rag, file_paths: List[Path]):
+async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
+    """Index multiple files concurrently
+
+    Args:
+        rag: LightRAG instance
+        file_paths: Paths to the files to index
+    """
     if not file_paths:
         return
     try:
         enqueued = False
 
         if len(file_paths) == 1:
             enqueued = await pipeline_enqueue_file(rag, file_paths[0])
         else:
@@ -327,7 +296,13 @@ async def pipeline_index_files(rag, file_paths: List[Path]):
         logging.error(traceback.format_exc())
 
 
-async def pipeline_index_texts(rag, texts: List[str]):
+async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
+    """Index a list of texts
+
+    Args:
+        rag: LightRAG instance
+        texts: The texts to index
+    """
     if not texts:
         return
     await rag.apipeline_enqueue_documents(texts)
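
A small driver for the text path shown above, assuming an already-initialized LightRAG instance (the constructor arguments here are placeholders; a usable instance needs LLM and embedding configuration):

    import asyncio

    from lightrag import LightRAG

    # Import path assumed from the relative imports in this diff; adjust to your layout.
    from lightrag.api.routers.document_routes import pipeline_index_texts


    async def main() -> None:
        # Placeholder construction, not taken from this diff.
        rag = LightRAG(working_dir="./rag_storage")
        await pipeline_index_texts(rag, ["First note to index.", "Second note to index."])


    if __name__ == "__main__":
        asyncio.run(main())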
@@ -335,16 +310,29 @@ async def pipeline_index_texts(rag, texts: List[str]):
 
 
 async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
+    """Save the uploaded file to a temporary location
+
+    Args:
+        file: The uploaded file
+
+    Returns:
+        Path: The path to the saved file
+    """
+    # Generate unique filename to avoid conflicts
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
+
+    # Create a temporary file to save the uploaded content
     temp_path = input_dir / "temp" / unique_filename
     temp_path.parent.mkdir(exist_ok=True)
+
+    # Save the file
     with open(temp_path, "wb") as buffer:
         shutil.copyfileobj(file.file, buffer)
     return temp_path
 
 
-async def run_scanning_process(rag, doc_manager: DocumentManager):
+async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     """Background task to scan and index documents"""
     try:
         new_files = doc_manager.scan_directory_for_new_files()
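
The temp-file convention above is prefix + second-resolution timestamp + the original filename, stored under a temp/ subfolder of the input directory. A standalone sketch of the same naming (the prefix value is a stand-in; the real temp_prefix is defined elsewhere in the module):

    from datetime import datetime
    from pathlib import Path

    temp_prefix = "__tmp_"  # stand-in value, not taken from this diff
    input_dir = Path("./inputs")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_filename = f"{temp_prefix}{timestamp}_report.pdf"

    temp_path = input_dir / "temp" / unique_filename
    # parents=True so the sketch runs standalone; the route code can assume input_dir already exists
    temp_path.parent.mkdir(parents=True, exist_ok=True)
    print(temp_path)  # e.g. inputs/temp/__tmp_20250218_093000_report.pdf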
@@ -375,7 +363,7 @@ async def run_scanning_process(rag, doc_manager: DocumentManager):
 
 
 def create_document_routes(
-    rag, doc_manager: DocumentManager, api_key: Optional[str] = None
+    rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
 ):
     optional_api_key = get_api_key_dependency(api_key)
 
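
For context, a hedged sketch of mounting these routes on an application. It assumes create_document_routes returns the APIRouter it populates (not shown in this hunk) and that rag and doc_manager are configured elsewhere:

    from fastapi import FastAPI

    from lightrag import LightRAG

    # Import path assumed from the relative imports in this diff; adjust to your layout.
    from lightrag.api.routers.document_routes import DocumentManager, create_document_routes


    def build_app(rag: LightRAG, doc_manager: DocumentManager) -> FastAPI:
        app = FastAPI()
        # Assumption: create_document_routes returns the APIRouter it populates.
        app.include_router(create_document_routes(rag, doc_manager, api_key=None))
        return app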
@@ -388,9 +376,6 @@ def create_document_routes(
         and processes them. If a scanning process is already running, it returns a status indicating
         that fact.
 
-        Args:
-            background_tasks (BackgroundTasks): FastAPI background tasks handler
-
         Returns:
             dict: A dictionary containing the scanning status
         """
@@ -402,6 +387,7 @@ def create_document_routes(
             scan_progress["indexed_count"] = 0
             scan_progress["progress"] = 0
 
+        # Start the scanning process in the background
         background_tasks.add_task(run_scanning_process, rag, doc_manager)
         return {"status": "scanning_started"}
 
@@ -453,6 +439,7 @@ def create_document_routes(
         with open(file_path, "wb") as buffer:
             shutil.copyfileobj(file.file, buffer)
 
+        # Add to background tasks
         background_tasks.add_task(pipeline_index_file, rag, file_path)
 
         return InsertResponse(
@@ -562,6 +549,8 @@ def create_document_routes(
             )
 
         temp_path = await save_temp_file(doc_manager.input_dir, file)
+
+        # Add to background tasks
         background_tasks.add_task(pipeline_index_file, rag, temp_path)
 
         return InsertResponse(
@@ -606,6 +595,7 @@ def create_document_routes(
 
         for file in files:
             if doc_manager.is_supported_file(file.filename):
+                # Create a temporary file to save the uploaded content
                 temp_files.append(await save_temp_file(doc_manager.input_dir, file))
                 inserted_count += 1
             else:
@@ -614,6 +604,7 @@ def create_document_routes(
         if temp_files:
             background_tasks.add_task(pipeline_index_files, rag, temp_files)
 
+        # Prepare status message
         if inserted_count == len(files):
             status = "success"
             status_message = f"Successfully inserted all {inserted_count} documents"
|