Complete the code for the Ollama-compatible interface

This commit is contained in:
yangdx
2025-01-15 14:31:49 +08:00
parent b97d1ecd72
commit be134878fe
2 changed files with 129 additions and 7 deletions

@@ -2,14 +2,11 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import logging
import argparse
from typing import List, Dict, Any, Optional
from lightrag import LightRAG, QueryParam
# from lightrag.llm import lollms_model_complete, lollms_embed
# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
from lightrag.llm import openai_complete_if_cache, ollama_embedding
# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum
from pathlib import Path
import shutil
@@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
load_dotenv()
# Constants for model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
@@ -219,21 +223,43 @@ class DocumentManager:
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
global_ = "global"  # global_ is used because "global" is a Python reserved keyword; the enum value is still the string "global"
hybrid = "hybrid"
# Ollama API compatible models
class OllamaMessage(BaseModel):
role: str
content: str
class OllamaChatRequest(BaseModel):
model: str = LIGHTRAG_MODEL
messages: List[OllamaMessage]
stream: bool = False
options: Optional[Dict[str, Any]] = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaVersionResponse(BaseModel):
version: str
build: str = "default"
class OllamaTagResponse(BaseModel):
models: List[Dict[str, Any]]
# Original LightRAG models
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
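The Ollama-compatible Pydantic models above describe the wire format the new endpoints speak. A minimal sketch of the JSON a client would send and the reply it would get back, derived from the models in this hunk (the question text and the answer placeholder are illustrative assumptions):

# Request body accepted by OllamaChatRequest ("model" defaults to "lightrag:latest")
request_payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/hybrid What is LightRAG?"}],  # illustrative question
    "stream": False,
}

# Response body produced by OllamaChatResponse for a non-streaming call
response_payload = {
    "model": "lightrag:latest",
    "created_at": "2024-01-15T00:00:00Z",
    "message": {"role": "assistant", "content": "..."},
    "done": True,
}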
@@ -555,6 +581,101 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Ollama compatible API endpoints
@app.get("/api/version")
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(
version="0.1.0"
)
@app.get("/api/tags")
async def get_tags():
"""Get available models"""
return OllamaTagResponse(
models=[{
"name": LIGHTRAG_NAME,
"tag": LIGHTRAG_TAG,
"size": 0,
"digest": LIGHTRAG_DIGEST,
"modified_at": LIGHTRAG_CREATED_AT
}]
)
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
"""Parse query prefix to determine search mode
Returns tuple of (cleaned_query, search_mode)
"""
mode_map = {
"/local ": SearchMode.local,
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid
}
for prefix, mode in mode_map.items():
if query.startswith(prefix):
return query[len(prefix):], mode
return query, SearchMode.hybrid
@app.post("/api/chat")
async def chat(request: OllamaChatRequest):
"""Handle chat completion requests"""
try:
# Convert chat format to query
query = request.messages[-1].content if request.messages else ""
# Parse query mode and clean query
cleaned_query, mode = parse_query_mode(query)
# Call RAG with determined mode
response = await rag.aquery(
cleaned_query,
param=QueryParam(
mode=mode,
stream=request.stream
)
)
if request.stream:
async def stream_generator():
result = ""
async for chunk in response:
result += chunk
yield OllamaChatResponse(
model=LIGHTRAG_MODEL,
created_at=LIGHTRAG_CREATED_AT,
message=OllamaMessage(
role="assistant",
content=chunk
),
done=False
)
# Send final message
yield OllamaChatResponse(
model=LIGHTRAG_MODEL,
created_at=LIGHTRAG_CREATED_AT,
message=OllamaMessage(
role="assistant",
content=result
),
done=True
)
return stream_generator()
else:
return OllamaChatResponse(
model=LIGHTRAG_MODEL,
created_at=LIGHTRAG_CREATED_AT,
message=OllamaMessage(
role="assistant",
content=response
),
done=True
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""

@@ -101,6 +101,7 @@ setuptools.setup(
entry_points={
"console_scripts": [
"lightrag-server=lightrag.api.lightrag_server:main [api]",
"lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
],
},
)
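The new lightrag-ollama console script resolves to a main() in lightrag.api.lightrag_ollama, mirroring the existing lightrag-server entry. A minimal sketch of what such an entry point could look like around the create_app(args) factory shown in the first file; the argument names, defaults, and the use of uvicorn are assumptions, not part of this commit:

import argparse
import uvicorn


def main():
    # Hypothetical CLI surface; the real script may expose different options.
    parser = argparse.ArgumentParser(description="LightRAG server with an Ollama-compatible API")
    parser.add_argument("--host", default="0.0.0.0", help="interface to bind")
    parser.add_argument("--port", type=int, default=9621, help="port to listen on")
    args = parser.parse_args()

    app = create_app(args)  # factory defined in lightrag_ollama (see diff above)
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()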