From 8ea179a98b6861597f8eb01b3030385cad0ce3f3 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 19 Jan 2025 04:44:30 +0800
Subject: [PATCH] Migrate Ollama API to lightrag_server.py

---
 .env.example                    |   6 +-
 lightrag/api/lightrag_server.py | 495 ++++++++++++++++++++++++++++----
 2 files changed, 438 insertions(+), 63 deletions(-)

diff --git a/.env.example b/.env.example
index 7d5c0fe5..68cb9d13 100644
--- a/.env.example
+++ b/.env.example
@@ -25,9 +25,9 @@ EMBEDDING_BINDING_HOST=http://host.docker.internal:11434
 EMBEDDING_MODEL=bge-m3:latest
 
 # Lollms example
-EMBEDDING_BINDING=lollms
-EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
-EMBEDDING_MODEL=bge-m3:latest
+# EMBEDDING_BINDING=lollms
+# EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
+# EMBEDDING_MODEL=bge-m3:latest
 
 # RAG Configuration
 MAX_ASYNC=4
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index b898277a..25e65879 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -1,7 +1,11 @@
-from fastapi import FastAPI, HTTPException, File, UploadFile, Form
+from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request
 from pydantic import BaseModel
 import logging
 import argparse
+import json
+import time
+import re
+from typing import List, Dict, Any, Optional, Union
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import lollms_model_complete, lollms_embed
 from lightrag.llm import ollama_model_complete, ollama_embed
@@ -10,7 +14,6 @@ from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 from lightrag.api import __api_version__
 
 from lightrag.utils import EmbeddingFunc
-from typing import Optional, List, Union, Any
 from enum import Enum
 from pathlib import Path
 import shutil
@@ -28,16 +31,41 @@ import pipmaster as pm
 from dotenv import load_dotenv
 
+load_dotenv()
+
+
+def estimate_tokens(text: str) -> int:
+    """Estimate the number of tokens in text
+    Chinese characters: approximately 1.5 tokens per character
+    English characters: approximately 0.25 tokens per character
+    """
+    # Use regex to match Chinese and non-Chinese characters separately
+    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
+    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
+
+    # Calculate estimated token count
+    tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
+
+    return int(tokens)
+
+
+# Constants for emulated Ollama model information
+LIGHTRAG_NAME = "lightrag"
+LIGHTRAG_TAG = "latest"
+LIGHTRAG_MODEL = "lightrag:latest"
+LIGHTRAG_SIZE = 7365960935  # it's a dummy value
+LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
+LIGHTRAG_DIGEST = "sha256:lightrag"
+
 
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
-        "ollama": "http://localhost:11434",
-        "lollms": "http://localhost:9600",
-        "azure_openai": "https://api.openai.com/v1",
-        "openai": "https://api.openai.com/v1",
+        "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
+        "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
+        "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
+        "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
     }
     return default_hosts.get(
-        binding_type, "http://localhost:11434"
+        binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
     )  # fallback to ollama if unknown
 
 
@@ -214,9 +242,7 @@ def parse_args() -> argparse.Namespace:
     Returns:
         argparse.Namespace: Parsed arguments
     """
-    # Load environment variables from .env file
-    load_dotenv()
-
+
     parser = argparse.ArgumentParser(
         description="LightRAG FastAPI Server with separate working and input directories"
     )
@@ -409,6 +435,53 @@ class SearchMode(str, Enum):
     local = "local"
     global_ = "global"
     hybrid = "hybrid"
+    mix = "mix"
+
+
+class OllamaMessage(BaseModel):
+    role: str
+    content: str
+    images: Optional[List[str]] = None
+
+
+class OllamaChatRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    messages: List[OllamaMessage]
+    stream: bool = True  # Default to streaming mode
+    options: Optional[Dict[str, Any]] = None
+
+
+class OllamaChatResponse(BaseModel):
+    model: str
+    created_at: str
+    message: OllamaMessage
+    done: bool
+
+
+class OllamaVersionResponse(BaseModel):
+    version: str
+
+
+class OllamaModelDetails(BaseModel):
+    parent_model: str
+    format: str
+    family: str
+    families: List[str]
+    parameter_size: str
+    quantization_level: str
+
+
+class OllamaModel(BaseModel):
+    name: str
+    model: str
+    size: int
+    digest: str
+    modified_at: str
+    details: OllamaModelDetails
+
+
+class OllamaTagResponse(BaseModel):
+    models: List[OllamaModel]
 
 
 class QueryRequest(BaseModel):
@@ -514,50 +587,107 @@ def create_app(args):
     # Initialize document manager
     doc_manager = DocumentManager(args.input_dir)
 
+    async def openai_alike_model_complete(
+        prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    ) -> str:
+        return await openai_complete_if_cache(
+            args.llm_model,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            base_url=args.llm_binding_host,
+            api_key=os.getenv("OPENAI_API_KEY"),
+            **kwargs,
+        )
+
+    async def azure_openai_model_complete(
+        prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    ) -> str:
+        return await azure_openai_complete_if_cache(
+            args.llm_model,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            base_url=args.llm_binding_host,
+            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+            api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
+            **kwargs,
+        )
+
     # Initialize RAG
-    rag = LightRAG(
-        working_dir=args.working_dir,
-        llm_model_func=lollms_model_complete
-        if args.llm_binding == "lollms"
-        else ollama_model_complete
-        if args.llm_binding == "ollama"
-        else azure_openai_complete_if_cache
-        if args.llm_binding == "azure_openai"
-        else openai_complete_if_cache,
-        llm_model_name=args.llm_model,
-        llm_model_max_async=args.max_async,
-        llm_model_max_token_size=args.max_tokens,
-        llm_model_kwargs={
-            "host": args.llm_binding_host,
-            "timeout": args.timeout,
-            "options": {"num_ctx": args.max_tokens},
-        },
-        embedding_func=EmbeddingFunc(
-            embedding_dim=args.embedding_dim,
-            max_token_size=args.max_embed_tokens,
-            func=lambda texts: lollms_embed(
-                texts,
-                embed_model=args.embedding_model,
-                host=args.embedding_binding_host,
-            )
-            if args.llm_binding == "lollms"
-            else ollama_embed(
-                texts,
-                embed_model=args.embedding_model,
-                host=args.embedding_binding_host,
-            )
-            if args.llm_binding == "ollama"
-            else azure_openai_embedding(
-                texts,
-                model=args.embedding_model,  # no host is used for openai
-            )
-            if args.llm_binding == "azure_openai"
-            else openai_embedding(
-                texts,
-                model=args.embedding_model,  # no host is used for openai
-            ),
-        ),
-    )
+    if args.llm_binding in ["lollms", "ollama"]:
+        rag = LightRAG(
+            working_dir=args.working_dir,
+            llm_model_func=lollms_model_complete
+            if args.llm_binding == "lollms"
+            else ollama_model_complete,
+            llm_model_name=args.llm_model,
+            llm_model_max_async=args.max_async,
+            llm_model_max_token_size=args.max_tokens,
+            llm_model_kwargs={
+                "host": args.llm_binding_host,
+                "timeout": args.timeout,
+                "options": {"num_ctx": args.max_tokens},
+            },
+            embedding_func=EmbeddingFunc(
+                embedding_dim=args.embedding_dim,
+                max_token_size=args.max_embed_tokens,
+                func=lambda texts: lollms_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "lollms"
+                else ollama_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "ollama"
+                else azure_openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                )
+                if args.embedding_binding == "azure_openai"
+                else openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                ),
+            ),
+        )
+    else:
+        rag = LightRAG(
+            working_dir=args.working_dir,
+            llm_model_func=azure_openai_model_complete
+            if args.llm_binding == "azure_openai"
+            else openai_alike_model_complete,
+            embedding_func=EmbeddingFunc(
+                embedding_dim=args.embedding_dim,
+                max_token_size=args.max_embed_tokens,
+                func=lambda texts: lollms_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "lollms"
+                else ollama_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "ollama"
+                else azure_openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                )
+                if args.embedding_binding == "azure_openai"
+                else openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                ),
+            ),
+        )
 
     async def index_file(file_path: Union[str, Path]) -> None:
         """Index all files inside the folder with support for multiple file formats
@@ -592,7 +722,7 @@ def create_app(args):
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from pypdf2 import PdfReader
+                from PyPDF2 import PdfReader
 
                 # PDF handling
                 reader = PdfReader(str(file_path))
@@ -711,13 +841,21 @@ def create_app(args):
                 ),
             )
 
+        # If the response is a string (e.g. a cache hit), return it directly
+        if isinstance(response, str):
+            return QueryResponse(response=response)
+
+        # Otherwise it is an async generator; collect the chunks into a complete
+        # response (this endpoint never streams, regardless of request.stream)
         if request.stream:
             result = ""
             async for chunk in response:
                 result += chunk
             return QueryResponse(response=result)
         else:
-            return QueryResponse(response=response)
+            result = ""
+            async for chunk in response:
+                result += chunk
+            return QueryResponse(response=result)
     except Exception as e:
         trace_exception(e)
         raise HTTPException(status_code=500, detail=str(e))
@@ -725,7 +863,7 @@ def create_app(args):
     @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
     async def query_text_stream(request: QueryRequest):
         try:
-            response = rag.query(
+            response = await rag.aquery(  # Use aquery instead of query, and add await
                 request.query,
                 param=QueryParam(
                     mode=request.mode,
                     stream=True,
                     only_need_context=False,
                 ),
             )
 
-            async def stream_generator():
-                async for chunk in response:
-                    yield chunk
+            from fastapi.responses import StreamingResponse
 
-            return stream_generator()
+            async def stream_generator():
+                if isinstance(response, str):
+                    # If it's a string, send it all at once
+                    yield f"{json.dumps({'response': response})}\n"
+                else:
+                    # If it's an async generator, send chunks one by one
+                    try:
+                        async for chunk in response:
+                            if chunk:  # Only send non-empty content
+                                yield f"{json.dumps({'response': chunk})}\n"
+                    except Exception as e:
+                        logging.error(f"Streaming error: {str(e)}")
+                        yield f"{json.dumps({'error': str(e)})}\n"
+
+            return StreamingResponse(
+                stream_generator(),
+                media_type="application/x-ndjson",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "application/x-ndjson",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Methods": "POST, OPTIONS",
+                    "Access-Control-Allow-Headers": "Content-Type",
+                    "X-Accel-Buffering": "no",  # Disable Nginx buffering
+                },
+            )
         except Exception as e:
+            trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
 
     @app.post(
@@ -790,7 +953,7 @@ def create_app(args):
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from pypdf2 import PdfReader
+                from PyPDF2 import PdfReader
                 from io import BytesIO
 
                 # Read PDF from memory
@@ -897,7 +1060,7 @@ def create_app(args):
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from pypdf2 import PdfReader
+                from PyPDF2 import PdfReader
                 from io import BytesIO
 
                 pdf_content = await file.read()
@@ -993,6 +1156,218 @@ def create_app(args):
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
 
+    # Ollama compatible API endpoints
+    @app.get("/api/version")
+    async def get_version():
+        """Get Ollama version information"""
+        return OllamaVersionResponse(version="0.5.4")
+
+    @app.get("/api/tags")
+    async def get_tags():
+        """Get available models"""
+        return OllamaTagResponse(
+            models=[
+                {
+                    "name": LIGHTRAG_MODEL,
+                    "model": LIGHTRAG_MODEL,
+                    "size": LIGHTRAG_SIZE,
+                    "digest": LIGHTRAG_DIGEST,
+                    "modified_at": LIGHTRAG_CREATED_AT,
+                    "details": {
+                        "parent_model": "",
+                        "format": "gguf",
+                        "family": LIGHTRAG_NAME,
+                        "families": [LIGHTRAG_NAME],
+                        "parameter_size": "13B",
+                        "quantization_level": "Q4_0",
+                    },
+                }
+            ]
+        )
+
+    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
+        """Parse query prefix to determine search mode
+        Returns tuple of (cleaned_query, search_mode)
+        """
+        mode_map = {
+            "/local ": SearchMode.local,
+            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive, + "/hybrid ": SearchMode.hybrid, + "/mix ": SearchMode.mix, + } + + for prefix, mode in mode_map.items(): + if query.startswith(prefix): + # After removing prefix an leading spaces + cleaned_query = query[len(prefix) :].lstrip() + return cleaned_query, mode + + return query, SearchMode.hybrid + + @app.post("/api/chat") + async def chat(raw_request: Request, request: OllamaChatRequest): + """Handle chat completion requests""" + try: + # Get all messages + messages = request.messages + if not messages: + raise HTTPException(status_code=400, detail="No messages provided") + + # Get the last message as query + query = messages[-1].content + + # 解析查询模式 + cleaned_query, mode = parse_query_mode(query) + + # 开始计时 + start_time = time.time_ns() + + # 计算输入token数量 + prompt_tokens = estimate_tokens(cleaned_query) + + # 调用RAG进行查询 + query_param = QueryParam( + mode=mode, stream=request.stream, only_need_context=False + ) + + if request.stream: + from fastapi.responses import StreamingResponse + + response = await rag.aquery( # Need await to get async generator + cleaned_query, param=query_param + ) + + async def stream_generator(): + try: + first_chunk_time = None + last_chunk_time = None + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": response, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + return # Ensure the generator ends immediately after sending the completion marker + except Exception as e: + logging.error(f"Error in stream_generator: {str(e)}") + raise + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + 
"Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + response_text = await rag.aquery(cleaned_query, param=query_param) + last_chunk_time = time.time_ns() + + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": str(response_text), + "images": None, + }, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status"""