From 387be31f090735d3ba95ee12655447c16b54cfde Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 19 Jan 2025 05:19:02 +0800 Subject: [PATCH] Refactor embedding function initialization and remove start-server.sh - Simplified RAG initialization logic by deduplicating embedding function - Removed start-server.sh script which is not needed - No functional changes to the application --- lightrag/api/lightrag_ollama.py | 924 -------------------------------- lightrag/api/lightrag_server.py | 78 +-- setup.py | 1 - start-server.sh | 3 - 4 files changed, 28 insertions(+), 978 deletions(-) delete mode 100644 lightrag/api/lightrag_ollama.py delete mode 100755 start-server.sh diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py deleted file mode 100644 index fc7ae29c..00000000 --- a/lightrag/api/lightrag_ollama.py +++ /dev/null @@ -1,924 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request -from pydantic import BaseModel -import logging -import argparse -import json -import time -import re -from typing import List, Dict, Any, Optional -from lightrag import LightRAG, QueryParam -from lightrag.llm import openai_complete_if_cache, ollama_embedding - -from lightrag.utils import EmbeddingFunc -from enum import Enum -from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception -import os - -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware - -from starlette.status import HTTP_403_FORBIDDEN - -from dotenv import load_dotenv - -load_dotenv() - - -def estimate_tokens(text: str) -> int: - """Estimate the number of tokens in text - Chinese characters: approximately 1.5 tokens per character - English characters: approximately 0.25 tokens per character - """ - # Use regex to match Chinese and non-Chinese characters separately - chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) - non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text)) - - # Calculate estimated token count - tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 - - return int(tokens) - - -# Constants for model information -LIGHTRAG_NAME = "lightrag" -LIGHTRAG_TAG = "latest" -LIGHTRAG_MODEL = "lightrag:latest" -LIGHTRAG_SIZE = 7365960935 -LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" -LIGHTRAG_DIGEST = "sha256:lightrag" - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - "deepseek-chat", - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=os.getenv("DEEPSEEK_API_KEY"), - base_url=os.getenv("DEEPSEEK_ENDPOINT"), - **kwargs, - ) - - -def get_default_host(binding_type: str) -> str: - default_hosts = { - "ollama": "http://m4.lan.znipower.com:11434", - "lollms": "http://localhost:9600", - "azure_openai": "https://api.openai.com/v1", - "openai": os.getenv("DEEPSEEK_ENDPOINT"), - } - return default_hosts.get( - binding_type, "http://localhost:11434" - ) # fallback to ollama if unknown - - -def parse_args(): - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with separate working and input directories" - ) - - # Start by the bindings - parser.add_argument( - "--llm-binding", - default="ollama", - help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)", - ) - parser.add_argument( - "--embedding-binding", - default="ollama", - help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)", - ) - - # Parse just these arguments first - temp_args, _ = parser.parse_known_args() - - # Add remaining arguments with dynamic defaults for hosts - # Server configuration - parser.add_argument( - "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=9621, help="Server port (default: 9621)" - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default="./rag_storage", - help="Working directory for RAG storage (default: ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default="./inputs", - help="Directory containing input documents (default: ./inputs)", - ) - - # LLM Model configuration - default_llm_host = get_default_host(temp_args.llm_binding) - parser.add_argument( - "--llm-binding-host", - default=default_llm_host, - help=f"llm server host URL (default: {default_llm_host})", - ) - - parser.add_argument( - "--llm-model", - default="mistral-nemo:latest", - help="LLM model name (default: mistral-nemo:latest)", - ) - - # Embedding model configuration - default_embedding_host = get_default_host(temp_args.embedding_binding) - parser.add_argument( - "--embedding-binding-host", - default=default_embedding_host, - help=f"embedding server host URL (default: {default_embedding_host})", - ) - - parser.add_argument( - "--embedding-model", - default="bge-m3:latest", - help="Embedding model name (default: bge-m3:latest)", - ) - - def timeout_type(value): - if value is None or value == "None": - return None - return int(value) - - parser.add_argument( - "--timeout", - default=None, - type=timeout_type, - help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", - ) - # RAG configuration - parser.add_argument( - "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" - ) - parser.add_argument( - "--max-tokens", - type=int, - default=32768, - help="Maximum token size (default: 32768)", - ) - parser.add_argument( - "--embedding-dim", - type=int, - default=1024, - help="Embedding dimensions (default: 1024)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=8192, - help="Maximum embedding token size (default: 8192)", - ) - - # Logging configuration - parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: INFO)", - ) - - parser.add_argument( - "--key", - type=str, - help="API key for authentication. This protects lightrag server against unauthorized access", - default=None, - ) - - # Optional https parameters - parser.add_argument( - "--ssl", action="store_true", help="Enable HTTPS (default: False)" - ) - parser.add_argument( - "--ssl-certfile", - default=None, - help="Path to SSL certificate file (required if --ssl is enabled)", - ) - parser.add_argument( - "--ssl-keyfile", - default=None, - help="Path to SSL private key file (required if --ssl is enabled)", - ) - return parser.parse_args() - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -# Pydantic models -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global" - hybrid = "hybrid" - mix = "mix" - - -# Ollama API compatible models -class OllamaMessage(BaseModel): - role: str - content: str - images: Optional[List[str]] = None - - -class OllamaChatRequest(BaseModel): - model: str = LIGHTRAG_MODEL - messages: List[OllamaMessage] - stream: bool = True # Default to streaming mode - options: Optional[Dict[str, Any]] = None - - -class OllamaChatResponse(BaseModel): - model: str - created_at: str - message: OllamaMessage - done: bool - - -class OllamaVersionResponse(BaseModel): - version: str - - -class OllamaModelDetails(BaseModel): - parent_model: str - format: str - family: str - families: List[str] - parameter_size: str - quantization_level: str - - -class OllamaModel(BaseModel): - name: str - model: str - size: int - digest: str - modified_at: str - details: OllamaModelDetails - - -class OllamaTagResponse(BaseModel): - models: List[OllamaModel] - - -# Original LightRAG models -class QueryRequest(BaseModel): - query: str - mode: SearchMode = SearchMode.hybrid - stream: bool = False - only_need_context: bool = False - - -class QueryResponse(BaseModel): - response: str - - -class InsertTextRequest(BaseModel): - text: str - description: Optional[str] = None - - -class InsertResponse(BaseModel): - status: str - message: str - document_count: int - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -def create_app(args): - # Verify that bindings arer correctly setup - if args.llm_binding not in ["lollms", "ollama", "openai"]: - raise Exception("llm binding not supported") - - if args.embedding_binding not in ["lollms", "ollama", "openai"]: - raise Exception("embedding binding not supported") - - # Add SSL validation - if args.ssl: - if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception( - "SSL certificate and key files must be provided when SSL is enabled" - ) - if not os.path.exists(args.ssl_certfile): - raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") - if not os.path.exists(args.ssl_keyfile): - raise Exception(f"SSL key file not found: {args.ssl_keyfile}") - - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - - # Check if API key is provided either through env var or args - api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - - # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" - + "(With authentication)" - if api_key - else "", - version="1.0.1", - openapi_tags=[{"name": "api"}], - ) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create the optional API key dependency - optional_api_key = get_api_key_dependency(api_key) - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) - - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) - - # Initialize RAG - rag = LightRAG( - working_dir=args.working_dir, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=1024, - max_token_size=8192, - func=lambda texts: ollama_embedding( - texts, - embed_model="bge-m3:latest", - host="http://m4.lan.znipower.com:11434", - ), - ), - ) - - @app.on_event("startup") - async def startup_event(): - """Index all files in input directory during startup""" - try: - new_files = doc_manager.scan_directory() - for file_path in new_files: - try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") - - logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") - - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """Manually trigger scanning for new documents""" - try: - new_files = doc_manager.scan_directory() - indexed_count = 0 - - for file_path in new_files: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - indexed_count += 1 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - return { - "status": "success", - "indexed_count": indexed_count, - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): - """Upload a file to the input directory""" - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=request.stream, - only_need_context=request.only_need_context, - ), - ) - - # If response is a string (e.g. cache hit), return directly - if isinstance(response, str): - return QueryResponse(response=response) - - # If it's an async generator, decide whether to stream based on stream parameter - if request.stream: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) - else: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - try: - response = await rag.aquery( # Use aquery instead of query, and add await - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - ), - ) - - from fastapi.responses import StreamingResponse - - async def stream_generator(): - if isinstance(response, str): - # If it's a string, send it all at once - yield f"{json.dumps({'response': response})}\n" - else: - # If it's an async generator, send chunks one by one - try: - async for chunk in response: - if chunk: # Only send non-empty content - yield f"{json.dumps({'response': chunk})}\n" - except Exception as e: - logging.error(f"Streaming error: {str(e)}") - yield f"{json.dumps({'error': str(e)})}\n" - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - "X-Accel-Buffering": "no", # Disable Nginx buffering - }, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - try: - await rag.ainsert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=1, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - try: - content = await file.read() - - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) - else: - raise HTTPException( - status_code=400, - detail="Unsupported file type. Only .txt and .md files are supported", - ) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", - message=status_message, - document_count=len(files), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - # Ollama compatible API endpoints - @app.get("/api/version") - async def get_version(): - """Get Ollama version information""" - return OllamaVersionResponse(version="0.5.4") - - @app.get("/api/tags") - async def get_tags(): - """Get available models""" - return OllamaTagResponse( - models=[ - { - "name": LIGHTRAG_MODEL, - "model": LIGHTRAG_MODEL, - "size": LIGHTRAG_SIZE, - "digest": LIGHTRAG_DIGEST, - "modified_at": LIGHTRAG_CREATED_AT, - "details": { - "parent_model": "", - "format": "gguf", - "family": LIGHTRAG_NAME, - "families": [LIGHTRAG_NAME], - "parameter_size": "13B", - "quantization_level": "Q4_0", - }, - } - ] - ) - - def parse_query_mode(query: str) -> tuple[str, SearchMode]: - """Parse query prefix to determine search mode - Returns tuple of (cleaned_query, search_mode) - """ - mode_map = { - "/local ": SearchMode.local, - "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword - "/naive ": SearchMode.naive, - "/hybrid ": SearchMode.hybrid, - "/mix ": SearchMode.mix, - } - - for prefix, mode in mode_map.items(): - if query.startswith(prefix): - # After removing prefix an leading spaces - cleaned_query = query[len(prefix) :].lstrip() - return cleaned_query, mode - - return query, SearchMode.hybrid - - @app.post("/api/chat") - async def chat(raw_request: Request, request: OllamaChatRequest): - """Handle chat completion requests""" - try: - # Get all messages - messages = request.messages - if not messages: - raise HTTPException(status_code=400, detail="No messages provided") - - # Get the last message as query - query = messages[-1].content - - # 解析查询模式 - cleaned_query, mode = parse_query_mode(query) - - # 开始计时 - start_time = time.time_ns() - - # 计算输入token数量 - prompt_tokens = estimate_tokens(cleaned_query) - - # 调用RAG进行查询 - query_param = QueryParam( - mode=mode, stream=request.stream, only_need_context=False - ) - - if request.stream: - from fastapi.responses import StreamingResponse - - response = await rag.aquery( # Need await to get async generator - cleaned_query, param=query_param - ) - - async def stream_generator(): - try: - first_chunk_time = None - last_chunk_time = None - total_response = "" - - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time - total_response = response - - data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": response, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - return # Ensure the generator ends immediately after sending the completion marker - except Exception as e: - logging.error(f"Error in stream_generator: {str(e)}") - raise - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - }, - ) - else: - first_chunk_time = time.time_ns() - response_text = await rag.aquery(cleaned_query, param=query_param) - last_chunk_time = time.time_ns() - - if not response_text: - response_text = "No response generated" - - completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - return { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": str(response_text), - "images": None, - }, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.get("/health", dependencies=[Depends(optional_api_key)]) - async def get_status(): - """Get current system status""" - return { - "status": "healthy", - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "indexed_files": len(doc_manager.indexed_files), - "configuration": { - # LLM configuration binding/host address (if applicable)/model (if applicable) - "llm_binding": args.llm_binding, - "llm_binding_host": args.llm_binding_host, - "llm_model": args.llm_model, - # embedding model configuration binding/host address (if applicable)/model (if applicable) - "embedding_binding": args.embedding_binding, - "embedding_binding_host": args.embedding_binding_host, - "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, - }, - } - - return app - - -def main(): - args = parse_args() - import uvicorn - - app = create_app(args) - uvicorn_config = { - "app": app, - "host": args.host, - "port": args.port, - } - if args.ssl: - uvicorn_config.update( - { - "ssl_certfile": args.ssl_certfile, - "ssl_keyfile": args.ssl_keyfile, - } - ) - uvicorn.run(**uvicorn_config) - - -if __name__ == "__main__": - main() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 25e65879..3f2e425e 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -615,6 +615,32 @@ def create_app(args): **kwargs, ) + embedding_func = EmbeddingFunc( + embedding_dim=args.embedding_dim, + max_token_size=args.max_embed_tokens, + func=lambda texts: lollms_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + ) + if args.embedding_binding == "lollms" + else ollama_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + ) + if args.embedding_binding == "ollama" + else azure_openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai + ) + if args.embedding_binding == "azure_openai" + else openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai + ), + ) + # Initialize RAG if args.llm_binding in ["lollms", "ollama"] : rag = LightRAG( @@ -630,31 +656,7 @@ def create_app(args): "timeout": args.timeout, "options": {"num_ctx": args.max_tokens}, }, - embedding_func=EmbeddingFunc( - embedding_dim=args.embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: lollms_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - ) - if args.embedding_binding == "lollms" - else ollama_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - ) - if args.embedding_binding == "ollama" - else azure_openai_embedding( - texts, - model=args.embedding_model, # no host is used for openai - ) - if args.embedding_binding == "azure_openai" - else openai_embedding( - texts, - model=args.embedding_model, # no host is used for openai - ), - ), + embedding_func=embedding_func, ) else : rag = LightRAG( @@ -662,31 +664,7 @@ def create_app(args): llm_model_func=azure_openai_model_complete if args.llm_binding == "azure_openai" else openai_alike_model_complete, - embedding_func=EmbeddingFunc( - embedding_dim=args.embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: lollms_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - ) - if args.embedding_binding == "lollms" - else ollama_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - ) - if args.embedding_binding == "ollama" - else azure_openai_embedding( - texts, - model=args.embedding_model, # no host is used for openai - ) - if args.embedding_binding == "azure_openai" - else openai_embedding( - texts, - model=args.embedding_model, # no host is used for openai - ), - ), + embedding_func=embedding_func, ) async def index_file(file_path: Union[str, Path]) -> None: diff --git a/setup.py b/setup.py index b5850d26..38eff646 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,6 @@ setuptools.setup( entry_points={ "console_scripts": [ "lightrag-server=lightrag.api.lightrag_server:main [api]", - "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]", ], }, ) diff --git a/start-server.sh b/start-server.sh deleted file mode 100755 index 4e143f37..00000000 --- a/start-server.sh +++ /dev/null @@ -1,3 +0,0 @@ -. venv/bin/activate - -lightrag-ollama --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024