Fixed linting

This commit is contained in:
Saifeddine ALOUI
2024-12-19 11:44:01 +01:00
parent 81ba55df7b
commit fe6ebfa995
6 changed files with 238 additions and 157 deletions

View File

@@ -169,4 +169,3 @@ This project is licensed under the MIT License - see the LICENSE file for detail
- Built with [FastAPI](https://fastapi.tiangolo.com/) - Built with [FastAPI](https://fastapi.tiangolo.com/)
- Uses [LightRAG](https://github.com/HKUDS/LightRAG) for document processing - Uses [LightRAG](https://github.com/HKUDS/LightRAG) for document processing
- Powered by [OpenAI](https://openai.com/) for language model inference - Powered by [OpenAI](https://openai.com/) for language model inference

View File

@@ -1,8 +1,5 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel from pydantic import BaseModel
import asyncio
import os
import logging import logging
import argparse import argparse
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
@@ -13,7 +10,8 @@ from enum import Enum
from pathlib import Path from pathlib import Path
import shutil import shutil
import aiofiles import aiofiles
from ascii_colors import ASCIIColors, trace_exception from ascii_colors import trace_exception
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -21,41 +19,80 @@ def parse_args():
) )
# Server configuration # Server configuration
parser.add_argument('--host', default='0.0.0.0', help='Server host (default: 0.0.0.0)') parser.add_argument(
parser.add_argument('--port', type=int, default=9621, help='Server port (default: 9621)') "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration # Directory configuration
parser.add_argument('--working-dir', default='./rag_storage', parser.add_argument(
help='Working directory for RAG storage (default: ./rag_storage)') "--working-dir",
parser.add_argument('--input-dir', default='./inputs', default="./rag_storage",
help='Directory containing input documents (default: ./inputs)') help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# Model configuration # Model configuration
parser.add_argument('--model', default='mistral-nemo:latest', help='LLM model name (default: mistral-nemo:latest)') parser.add_argument(
parser.add_argument('--embedding-model', default='bge-m3:latest', "--model",
help='Embedding model name (default: bge-m3:latest)') default="mistral-nemo:latest",
parser.add_argument('--ollama-host', default='http://localhost:11434', help="LLM model name (default: mistral-nemo:latest)",
help='Ollama host URL (default: http://localhost:11434)') )
parser.add_argument(
"--embedding-model",
default="bge-m3:latest",
help="Embedding model name (default: bge-m3:latest)",
)
parser.add_argument(
"--ollama-host",
default="http://localhost:11434",
help="Ollama host URL (default: http://localhost:11434)",
)
# RAG configuration # RAG configuration
parser.add_argument('--max-async', type=int, default=4, help='Maximum async operations (default: 4)') parser.add_argument(
parser.add_argument('--max-tokens', type=int, default=32768, help='Maximum token size (default: 32768)') "--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
parser.add_argument('--embedding-dim', type=int, default=1024, )
help='Embedding dimensions (default: 1024)') parser.add_argument(
parser.add_argument('--max-embed-tokens', type=int, default=8192, "--max-tokens",
help='Maximum embedding token size (default: 8192)') type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=1024,
help="Embedding dimensions (default: 1024)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
# Logging configuration # Logging configuration
parser.add_argument('--log-level', default='INFO', parser.add_argument(
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], "--log-level",
help='Logging level (default: INFO)') default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
return parser.parse_args() return parser.parse_args()
class DocumentManager: class DocumentManager:
"""Handles document operations and tracking""" """Handles document operations and tracking"""
def __init__(self, input_dir: str, supported_extensions: tuple = ('.txt', '.md')): def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
self.input_dir = Path(input_dir) self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions self.supported_extensions = supported_extensions
self.indexed_files = set() self.indexed_files = set()
@@ -67,7 +104,7 @@ class DocumentManager:
"""Scan input directory for new files""" """Scan input directory for new files"""
new_files = [] new_files = []
for ext in self.supported_extensions: for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f'*{ext}'): for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files: if file_path not in self.indexed_files:
new_files.append(file_path) new_files.append(file_path)
return new_files return new_files
@@ -80,6 +117,7 @@ class DocumentManager:
"""Check if file type is supported""" """Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions) return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models # Pydantic models
class SearchMode(str, Enum): class SearchMode(str, Enum):
naive = "naive" naive = "naive"
@@ -87,31 +125,38 @@ class SearchMode(str, Enum):
global_ = "global" global_ = "global"
hybrid = "hybrid" hybrid = "hybrid"
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str query: str
mode: SearchMode = SearchMode.hybrid mode: SearchMode = SearchMode.hybrid
stream: bool = False stream: bool = False
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
response: str response: str
class InsertTextRequest(BaseModel): class InsertTextRequest(BaseModel):
text: str text: str
description: Optional[str] = None description: Optional[str] = None
class InsertResponse(BaseModel): class InsertResponse(BaseModel):
status: str status: str
message: str message: str
document_count: int document_count: int
def create_app(args): def create_app(args):
# Setup logging # Setup logging
logging.basicConfig(format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)) logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Initialize FastAPI app # Initialize FastAPI app
app = FastAPI( app = FastAPI(
title="LightRAG API", title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories" description="API for querying text using LightRAG with separate storage and input directories",
) )
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
@@ -127,7 +172,10 @@ def create_app(args):
llm_model_name=args.model, llm_model_name=args.model,
llm_model_max_async=args.max_async, llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens, llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={"host": args.ollama_host, "options": {"num_ctx": args.max_tokens}}, llm_model_kwargs={
"host": args.ollama_host,
"options": {"num_ctx": args.max_tokens},
},
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim, embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens, max_token_size=args.max_embed_tokens,
@@ -136,6 +184,7 @@ def create_app(args):
), ),
), ),
) )
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
"""Index all files in input directory during startup""" """Index all files in input directory during startup"""
@@ -144,7 +193,7 @@ def create_app(args):
for file_path in new_files: for file_path in new_files:
try: try:
# Use async file reading # Use async file reading
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read() content = await f.read()
# Use the async version of insert directly # Use the async version of insert directly
await rag.ainsert(content) await rag.ainsert(content)
@@ -168,7 +217,7 @@ def create_app(args):
for file_path in new_files: for file_path in new_files:
try: try:
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
rag.insert(content) rag.insert(content)
doc_manager.mark_as_indexed(file_path) doc_manager.mark_as_indexed(file_path)
@@ -179,7 +228,7 @@ def create_app(args):
return { return {
"status": "success", "status": "success",
"indexed_count": indexed_count, "indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files) "total_documents": len(doc_manager.indexed_files),
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -191,7 +240,7 @@ def create_app(args):
if not doc_manager.is_supported_file(file.filename): if not doc_manager.is_supported_file(file.filename):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}" detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
) )
file_path = doc_manager.input_dir / file.filename file_path = doc_manager.input_dir / file.filename
@@ -207,7 +256,7 @@ def create_app(args):
return { return {
"status": "success", "status": "success",
"message": f"File uploaded and indexed: {file.filename}", "message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files) "total_documents": len(doc_manager.indexed_files),
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -217,7 +266,7 @@ def create_app(args):
try: try:
response = await rag.aquery( response = await rag.aquery(
request.query, request.query,
param=QueryParam(mode=request.mode, stream=request.stream) param=QueryParam(mode=request.mode, stream=request.stream),
) )
if request.stream: if request.stream:
@@ -234,8 +283,7 @@ def create_app(args):
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
try: try:
response = rag.query( response = rag.query(
request.query, request.query, param=QueryParam(mode=request.mode, stream=True)
param=QueryParam(mode=request.mode, stream=True)
) )
async def stream_generator(): async def stream_generator():
@@ -253,32 +301,29 @@ def create_app(args):
return InsertResponse( return InsertResponse(
status="success", status="success",
message="Text successfully inserted", message="Text successfully inserted",
document_count=len(rag) document_count=len(rag),
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse) @app.post("/documents/file", response_model=InsertResponse)
async def insert_file( async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
file: UploadFile = File(...),
description: str = Form(None)
):
try: try:
content = await file.read() content = await file.read()
if file.filename.endswith(('.txt', '.md')): if file.filename.endswith((".txt", ".md")):
text = content.decode('utf-8') text = content.decode("utf-8")
rag.insert(text) rag.insert(text)
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported" detail="Unsupported file type. Only .txt and .md files are supported",
) )
return InsertResponse( return InsertResponse(
status="success", status="success",
message=f"File '{file.filename}' successfully inserted", message=f"File '{file.filename}' successfully inserted",
document_count=len(rag) document_count=len(rag),
) )
except UnicodeDecodeError: except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported") raise HTTPException(status_code=400, detail="File encoding not supported")
@@ -294,8 +339,8 @@ def create_app(args):
for file in files: for file in files:
try: try:
content = await file.read() content = await file.read()
if file.filename.endswith(('.txt', '.md')): if file.filename.endswith((".txt", ".md")):
text = content.decode('utf-8') text = content.decode("utf-8")
rag.insert(text) rag.insert(text)
inserted_count += 1 inserted_count += 1
else: else:
@@ -310,7 +355,7 @@ def create_app(args):
return InsertResponse( return InsertResponse(
status="success" if inserted_count > 0 else "partial_success", status="success" if inserted_count > 0 else "partial_success",
message=status_message, message=status_message,
document_count=len(rag) document_count=len(rag),
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -324,12 +369,11 @@ def create_app(args):
return InsertResponse( return InsertResponse(
status="success", status="success",
message="All documents cleared successfully", message="All documents cleared successfully",
document_count=0 document_count=0,
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/health") @app.get("/health")
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
@@ -342,14 +386,16 @@ def create_app(args):
"model": args.model, "model": args.model,
"embedding_model": args.embedding_model, "embedding_model": args.embedding_model,
"max_tokens": args.max_tokens, "max_tokens": args.max_tokens,
"ollama_host": args.ollama_host "ollama_host": args.ollama_host,
} },
} }
return app return app
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
import uvicorn import uvicorn
app = create_app(args) app = create_app(args)
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)

View File

@@ -1,8 +1,6 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel from pydantic import BaseModel
import asyncio import asyncio
import os
import logging import logging
import argparse import argparse
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
@@ -13,49 +11,77 @@ from enum import Enum
from pathlib import Path from pathlib import Path
import shutil import shutil
import aiofiles import aiofiles
from ascii_colors import ASCIIColors, trace_exception from ascii_colors import trace_exception
import numpy as np
import nest_asyncio import nest_asyncio
# Apply nest_asyncio to solve event loop issues # Apply nest_asyncio to solve event loop issues
nest_asyncio.apply() nest_asyncio.apply()
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with OpenAI integration" description="LightRAG FastAPI Server with OpenAI integration"
) )
# Server configuration # Server configuration
parser.add_argument('--host', default='0.0.0.0', help='Server host (default: 0.0.0.0)') parser.add_argument(
parser.add_argument('--port', type=int, default=9621, help='Server port (default: 9621)') "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration # Directory configuration
parser.add_argument('--working-dir', default='./rag_storage', parser.add_argument(
help='Working directory for RAG storage (default: ./rag_storage)') "--working-dir",
parser.add_argument('--input-dir', default='./inputs', default="./rag_storage",
help='Directory containing input documents (default: ./inputs)') help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# Model configuration # Model configuration
parser.add_argument('--model', default='gpt-4', help='OpenAI model name (default: gpt-4)') parser.add_argument(
parser.add_argument('--embedding-model', default='text-embedding-3-large', "--model", default="gpt-4", help="OpenAI model name (default: gpt-4)"
help='OpenAI embedding model (default: text-embedding-3-large)') )
parser.add_argument(
"--embedding-model",
default="text-embedding-3-large",
help="OpenAI embedding model (default: text-embedding-3-large)",
)
# RAG configuration # RAG configuration
parser.add_argument('--max-tokens', type=int, default=32768, help='Maximum token size (default: 32768)') parser.add_argument(
parser.add_argument('--max-embed-tokens', type=int, default=8192, "--max-tokens",
help='Maximum embedding token size (default: 8192)') type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
# Logging configuration # Logging configuration
parser.add_argument('--log-level', default='INFO', parser.add_argument(
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], "--log-level",
help='Logging level (default: INFO)') default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
return parser.parse_args() return parser.parse_args()
class DocumentManager: class DocumentManager:
"""Handles document operations and tracking""" """Handles document operations and tracking"""
def __init__(self, input_dir: str, supported_extensions: tuple = ('.txt', '.md')): def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
self.input_dir = Path(input_dir) self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions self.supported_extensions = supported_extensions
self.indexed_files = set() self.indexed_files = set()
@@ -67,7 +93,7 @@ class DocumentManager:
"""Scan input directory for new files""" """Scan input directory for new files"""
new_files = [] new_files = []
for ext in self.supported_extensions: for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f'*{ext}'): for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files: if file_path not in self.indexed_files:
new_files.append(file_path) new_files.append(file_path)
return new_files return new_files
@@ -80,6 +106,7 @@ class DocumentManager:
"""Check if file type is supported""" """Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions) return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models # Pydantic models
class SearchMode(str, Enum): class SearchMode(str, Enum):
naive = "naive" naive = "naive"
@@ -87,37 +114,45 @@ class SearchMode(str, Enum):
global_ = "global" global_ = "global"
hybrid = "hybrid" hybrid = "hybrid"
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str query: str
mode: SearchMode = SearchMode.hybrid mode: SearchMode = SearchMode.hybrid
stream: bool = False stream: bool = False
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
response: str response: str
class InsertTextRequest(BaseModel): class InsertTextRequest(BaseModel):
text: str text: str
description: Optional[str] = None description: Optional[str] = None
class InsertResponse(BaseModel): class InsertResponse(BaseModel):
status: str status: str
message: str message: str
document_count: int document_count: int
async def get_embedding_dim(embedding_model: str) -> int: async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model""" """Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."] test_text = ["This is a test sentence."]
embedding = await openai_embedding(test_text, model=embedding_model) embedding = await openai_embedding(test_text, model=embedding_model)
return embedding.shape[1] return embedding.shape[1]
def create_app(args): def create_app(args):
# Setup logging # Setup logging
logging.basicConfig(format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)) logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Initialize FastAPI app # Initialize FastAPI app
app = FastAPI( app = FastAPI(
title="LightRAG API", title="LightRAG API",
description="API for querying text using LightRAG with OpenAI integration" description="API for querying text using LightRAG with OpenAI integration",
) )
# Create working directory if it doesn't exist # Create working directory if it doesn't exist
@@ -129,6 +164,18 @@ def create_app(args):
# Get embedding dimensions # Get embedding dimensions
embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model)) embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model))
async def async_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
):
"""Async wrapper for OpenAI completion"""
return await openai_complete_if_cache(
args.model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
# Initialize RAG with OpenAI configuration # Initialize RAG with OpenAI configuration
rag = LightRAG( rag = LightRAG(
working_dir=args.working_dir, working_dir=args.working_dir,
@@ -142,15 +189,6 @@ def create_app(args):
), ),
) )
async def async_openai_complete(prompt, system_prompt=None, history_messages=[], **kwargs):
"""Async wrapper for OpenAI completion"""
return await openai_complete_if_cache(
args.model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs
)
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
"""Index all files in input directory during startup""" """Index all files in input directory during startup"""
@@ -159,7 +197,7 @@ def create_app(args):
for file_path in new_files: for file_path in new_files:
try: try:
# Use async file reading # Use async file reading
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read() content = await f.read()
# Use the async version of insert directly # Use the async version of insert directly
await rag.ainsert(content) await rag.ainsert(content)
@@ -183,7 +221,7 @@ def create_app(args):
for file_path in new_files: for file_path in new_files:
try: try:
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
rag.insert(content) rag.insert(content)
doc_manager.mark_as_indexed(file_path) doc_manager.mark_as_indexed(file_path)
@@ -194,7 +232,7 @@ def create_app(args):
return { return {
"status": "success", "status": "success",
"indexed_count": indexed_count, "indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files) "total_documents": len(doc_manager.indexed_files),
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -206,7 +244,7 @@ def create_app(args):
if not doc_manager.is_supported_file(file.filename): if not doc_manager.is_supported_file(file.filename):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}" detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
) )
file_path = doc_manager.input_dir / file.filename file_path = doc_manager.input_dir / file.filename
@@ -222,7 +260,7 @@ def create_app(args):
return { return {
"status": "success", "status": "success",
"message": f"File uploaded and indexed: {file.filename}", "message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files) "total_documents": len(doc_manager.indexed_files),
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -232,7 +270,7 @@ def create_app(args):
try: try:
response = await rag.aquery( response = await rag.aquery(
request.query, request.query,
param=QueryParam(mode=request.mode, stream=request.stream) param=QueryParam(mode=request.mode, stream=request.stream),
) )
if request.stream: if request.stream:
@@ -249,8 +287,7 @@ def create_app(args):
async def query_text_stream(request: QueryRequest): async def query_text_stream(request: QueryRequest):
try: try:
response = rag.query( response = rag.query(
request.query, request.query, param=QueryParam(mode=request.mode, stream=True)
param=QueryParam(mode=request.mode, stream=True)
) )
async def stream_generator(): async def stream_generator():
@@ -268,32 +305,29 @@ def create_app(args):
return InsertResponse( return InsertResponse(
status="success", status="success",
message="Text successfully inserted", message="Text successfully inserted",
document_count=len(rag) document_count=len(rag),
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/file", response_model=InsertResponse) @app.post("/documents/file", response_model=InsertResponse)
async def insert_file( async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
file: UploadFile = File(...),
description: str = Form(None)
):
try: try:
content = await file.read() content = await file.read()
if file.filename.endswith(('.txt', '.md')): if file.filename.endswith((".txt", ".md")):
text = content.decode('utf-8') text = content.decode("utf-8")
rag.insert(text) rag.insert(text)
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported" detail="Unsupported file type. Only .txt and .md files are supported",
) )
return InsertResponse( return InsertResponse(
status="success", status="success",
message=f"File '{file.filename}' successfully inserted", message=f"File '{file.filename}' successfully inserted",
document_count=len(rag) document_count=len(rag),
) )
except UnicodeDecodeError: except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported") raise HTTPException(status_code=400, detail="File encoding not supported")
@@ -309,8 +343,8 @@ def create_app(args):
for file in files: for file in files:
try: try:
content = await file.read() content = await file.read()
if file.filename.endswith(('.txt', '.md')): if file.filename.endswith((".txt", ".md")):
text = content.decode('utf-8') text = content.decode("utf-8")
rag.insert(text) rag.insert(text)
inserted_count += 1 inserted_count += 1
else: else:
@@ -325,7 +359,7 @@ def create_app(args):
return InsertResponse( return InsertResponse(
status="success" if inserted_count > 0 else "partial_success", status="success" if inserted_count > 0 else "partial_success",
message=status_message, message=status_message,
document_count=len(rag) document_count=len(rag),
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -339,7 +373,7 @@ def create_app(args):
return InsertResponse( return InsertResponse(
status="success", status="success",
message="All documents cleared successfully", message="All documents cleared successfully",
document_count=0 document_count=0,
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -356,14 +390,16 @@ def create_app(args):
"model": args.model, "model": args.model,
"embedding_model": args.embedding_model, "embedding_model": args.embedding_model,
"max_tokens": args.max_tokens, "max_tokens": args.max_tokens,
"embedding_dim": embedding_dim "embedding_dim": embedding_dim,
} },
} }
return app return app
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
import uvicorn import uvicorn
app = create_app(args) app = create_app(args)
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)

View File

@@ -1,4 +1,4 @@
fastapi
uvicorn
python-multipart
ascii_colors ascii_colors
fastapi
python-multipart
uvicorn