diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py
index 42ae68f4..6f1ec9a4 100644
--- a/lightrag/api/lightrag_ollama.py
+++ b/lightrag/api/lightrag_ollama.py
@@ -2,14 +2,12 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
 from pydantic import BaseModel
 import logging
 import argparse
+from typing import List, Dict, Any, Optional
+from fastapi.responses import StreamingResponse  # needed to stream /api/chat
 from lightrag import LightRAG, QueryParam
-# from lightrag.llm import lollms_model_complete, lollms_embed
-# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
 from lightrag.llm import openai_complete_if_cache, ollama_embedding
-# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 from lightrag.utils import EmbeddingFunc
-from typing import Optional, List
 from enum import Enum
 from pathlib import Path
 import shutil
@@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN
 from dotenv import load_dotenv
 load_dotenv()
 
+# Constants for model information
+LIGHTRAG_NAME = "lightrag"
+LIGHTRAG_TAG = "latest"
+LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"  # f-string, so the constants are interpolated
+LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
+LIGHTRAG_DIGEST = "sha256:lightrag"
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -219,21 +223,43 @@ class DocumentManager:
 class SearchMode(str, Enum):
     naive = "naive"
     local = "local"
-    global_ = "global"
+    global_ = "global"  # global_ is used because "global" is a reserved Python keyword; the enum value is still the string "global"
     hybrid = "hybrid"
 
+# Ollama API compatible models
+class OllamaMessage(BaseModel):
+    role: str
+    content: str
+
+class OllamaChatRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    messages: List[OllamaMessage]
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+class OllamaChatResponse(BaseModel):
+    model: str
+    created_at: str
+    message: OllamaMessage
+    done: bool
+
+class OllamaVersionResponse(BaseModel):
+    version: str
+    build: str = "default"
+
+class OllamaTagResponse(BaseModel):
+    models: List[Dict[str, Any]]  # Any, not str: "size" is an int
+
+# Original LightRAG models
 class QueryRequest(BaseModel):
     query: str
     mode: SearchMode = SearchMode.hybrid
     stream: bool = False
     only_need_context: bool = False
 
-
 class QueryResponse(BaseModel):
     response: str
 
-
 class InsertTextRequest(BaseModel):
     text: str
     description: Optional[str] = None
@@ -555,6 +581,101 @@ def create_app(args):
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
 
+    # Ollama compatible API endpoints
+    @app.get("/api/version")
+    async def get_version():
+        """Get Ollama version information"""
+        return OllamaVersionResponse(version="0.1.0")
+
+    @app.get("/api/tags")
+    async def get_tags():
+        """Get available models"""
+        return OllamaTagResponse(
+            models=[
+                {
+                    "name": LIGHTRAG_NAME,
+                    "tag": LIGHTRAG_TAG,
+                    "size": 0,
+                    "digest": LIGHTRAG_DIGEST,
+                    "modified_at": LIGHTRAG_CREATED_AT,
+                }
+            ]
+        )
+
+    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
+        """Parse query prefix to determine search mode.
+
+        Returns a tuple of (cleaned_query, search_mode).
+        """
+        mode_map = {
+            "/local ": SearchMode.local,
+            "/global ": SearchMode.global_,  # global_ because "global" is a Python keyword
+            "/naive ": SearchMode.naive,
+            "/hybrid ": SearchMode.hybrid,
+        }
+
+        for prefix, mode in mode_map.items():
+            if query.startswith(prefix):
+                return query[len(prefix):], mode
+
+        return query, SearchMode.hybrid
+
+    @app.post("/api/chat")
+    async def chat(request: OllamaChatRequest):
+        """Handle chat completion requests"""
+        try:
+            # Convert chat format to a query: only the latest message is used
+            query = request.messages[-1].content if request.messages else ""
+
+            # Parse query mode and clean query
+            cleaned_query, mode = parse_query_mode(query)
+
+            # Call RAG with the determined mode
+            response = await rag.aquery(
+                cleaned_query,
+                param=QueryParam(mode=mode, stream=request.stream),
+            )
+
+            if request.stream:
+
+                async def stream_generator():
+                    result = ""
+                    async for chunk in response:
+                        result += chunk
+                        # Emit each chunk as one NDJSON line, as Ollama does
+                        yield (
+                            OllamaChatResponse(
+                                model=LIGHTRAG_MODEL,
+                                created_at=LIGHTRAG_CREATED_AT,
+                                message=OllamaMessage(role="assistant", content=chunk),
+                                done=False,
+                            ).model_dump_json()
+                            + "\n"
+                        )
+                    # Send the final message with the accumulated result
+                    yield (
+                        OllamaChatResponse(
+                            model=LIGHTRAG_MODEL,
+                            created_at=LIGHTRAG_CREATED_AT,
+                            message=OllamaMessage(role="assistant", content=result),
+                            done=True,
+                        ).model_dump_json()
+                        + "\n"
+                    )
+
+                return StreamingResponse(
+                    stream_generator(), media_type="application/x-ndjson"
+                )
+            else:
+                return OllamaChatResponse(
+                    model=LIGHTRAG_MODEL,
+                    created_at=LIGHTRAG_CREATED_AT,
+                    message=OllamaMessage(role="assistant", content=response),
+                    done=True,
+                )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.get("/health", dependencies=[Depends(optional_api_key)])
     async def get_status():
         """Get current system status"""
diff --git a/setup.py b/setup.py
index 38eff646..b5850d26 100644
--- a/setup.py
+++ b/setup.py
@@ -101,6 +101,7 @@ setuptools.setup(
     entry_points={
         "console_scripts": [
            "lightrag-server=lightrag.api.lightrag_server:main [api]",
+           "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
        ],
    },
)
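
Below is a minimal usage sketch for the Ollama-compatible endpoints introduced above. It is illustrative only: the base URL and port are hypothetical placeholders (the real bind address comes from the server's CLI arguments, which this diff does not show), while the "/local " prefix handling and the NDJSON streaming framing follow the /api/chat handler in the hunk above.

# Hedged client sketch for the Ollama-compatible endpoints in this diff.
# Assumption: the server is reachable at http://localhost:9621 (hypothetical).
import json

import requests

BASE_URL = "http://localhost:9621"

# Discovery endpoints mirror Ollama's /api/version and /api/tags.
print(requests.get(f"{BASE_URL}/api/version").json())
print(requests.get(f"{BASE_URL}/api/tags").json())

# A "/local " prefix on the last message selects SearchMode.local;
# any unrecognized prefix falls back to hybrid mode.
payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/local What entities are indexed?"}],
    "stream": True,
}
with requests.post(f"{BASE_URL}/api/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if chunk["done"]:
            break  # the final frame repeats the accumulated text; stop here
        print(chunk["message"]["content"], end="", flush=True)

Note that, as written, the final done frame carries the full accumulated response rather than an empty message, so a client that naively concatenates every frame would print the text twice; the sketch therefore stops at the done frame.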