Complete the Ollama-compatible API interface code

yangdx
2025-01-15 14:31:49 +08:00
parent b97d1ecd72
commit be134878fe
2 changed files with 129 additions and 7 deletions

View File

@@ -2,14 +2,11 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import logging
import argparse
from typing import List, Dict, Any, Optional
from lightrag import LightRAG, QueryParam
# from lightrag.llm import lollms_model_complete, lollms_embed
# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
from lightrag.llm import openai_complete_if_cache, ollama_embedding
# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum
from pathlib import Path
import shutil
@@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
load_dotenv()
# Constants for model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
@@ -219,21 +223,43 @@ class DocumentManager:
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"  # 'global_' is used because 'global' is a Python reserved keyword, but the enum value is still the string "global"
    hybrid = "hybrid"
# Ollama API compatible models
class OllamaMessage(BaseModel):
    role: str
    content: str

class OllamaChatRequest(BaseModel):
    model: str = LIGHTRAG_MODEL
    messages: List[OllamaMessage]
    stream: bool = False
    options: Optional[Dict[str, Any]] = None

class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool

class OllamaVersionResponse(BaseModel):
    version: str
    build: str = "default"

class OllamaTagResponse(BaseModel):
    models: List[Dict[str, str]]
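As an editorial illustration (not part of the commit), the request model above fills in its declared defaults, and the "/naive " prefix shown here is handled by parse_query_mode further down in this file:

# Illustrative only: model, stream and options fall back to their declared defaults
req = OllamaChatRequest(
    messages=[OllamaMessage(role="user", content="/naive What is LightRAG?")]
)
assert req.model == "lightrag:latest"  # LIGHTRAG_MODEL
assert req.stream is False and req.options is None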
# Original LightRAG models
class QueryRequest(BaseModel):
    query: str
    mode: SearchMode = SearchMode.hybrid
    stream: bool = False
    only_need_context: bool = False

class QueryResponse(BaseModel):
    response: str

class InsertTextRequest(BaseModel):
    text: str
    description: Optional[str] = None
@@ -555,6 +581,101 @@ def create_app(args):
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    # Ollama compatible API endpoints
    @app.get("/api/version")
    async def get_version():
        """Get Ollama version information"""
        return OllamaVersionResponse(
            version="0.1.0"
        )

    @app.get("/api/tags")
    async def get_tags():
        """Get available models"""
        return OllamaTagResponse(
            models=[{
                "name": LIGHTRAG_NAME,
                "tag": LIGHTRAG_TAG,
                "size": 0,
                "digest": LIGHTRAG_DIGEST,
                "modified_at": LIGHTRAG_CREATED_AT
            }]
        )
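For illustration only (not part of the commit), the handler above serializes to roughly the body sketched below. Note that OllamaTagResponse annotates models as List[Dict[str, str]], so depending on the installed Pydantic version the integer size may be coerced to the string "0" or rejected at validation time:

# Approximate /api/tags response, using the constants defined at the top of the file
expected_tags = {
    "models": [{
        "name": "lightrag",
        "tag": "latest",
        "size": 0,
        "digest": "sha256:lightrag",
        "modified_at": "2024-01-15T00:00:00Z",
    }]
}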
    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
        """Parse query prefix to determine search mode
        Returns tuple of (cleaned_query, search_mode)
        """
        mode_map = {
            "/local ": SearchMode.local,
            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
            "/naive ": SearchMode.naive,
            "/hybrid ": SearchMode.hybrid
        }

        for prefix, mode in mode_map.items():
            if query.startswith(prefix):
                return query[len(prefix):], mode

        return query, SearchMode.hybrid
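For illustration (not part of the commit), assuming the helper above is in scope, the prefix routing behaves as follows:

parse_query_mode("/local who wrote the report?")  # -> ("who wrote the report?", SearchMode.local)
parse_query_mode("a question with no prefix")     # -> ("a question with no prefix", SearchMode.hybrid)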
@app.post("/api/chat")
async def chat(request: OllamaChatRequest):
"""Handle chat completion requests"""
try:
# Convert chat format to query
query = request.messages[-1].content if request.messages else ""
# Parse query mode and clean query
cleaned_query, mode = parse_query_mode(query)
# Call RAG with determined mode
response = await rag.aquery(
cleaned_query,
param=QueryParam(
mode=mode,
stream=request.stream
)
)
if request.stream:
async def stream_generator():
result = ""
async for chunk in response:
result += chunk
yield OllamaChatResponse(
model=LIGHTRAG_MODEL,
created_at=LIGHTRAG_CREATED_AT,
message=OllamaMessage(
role="assistant",
content=chunk
),
done=False
)
# Send final message
yield OllamaChatResponse(
model=LIGHTRAG_MODEL,
created_at=LIGHTRAG_CREATED_AT,
message=OllamaMessage(
role="assistant",
content=result
),
done=True
)
return stream_generator()
else:
return OllamaChatResponse(
model=LIGHTRAG_MODEL,
created_at=LIGHTRAG_CREATED_AT,
message=OllamaMessage(
role="assistant",
content=response
),
done=True
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)]) @app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""

View File

@@ -101,6 +101,7 @@ setuptools.setup(
    entry_points={
        "console_scripts": [
            "lightrag-server=lightrag.api.lightrag_server:main [api]",
            "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
        ],
    },
)
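The new console script maps to lightrag.api.lightrag_ollama:main under the [api] extra. As a hedged illustration, assuming only what the declaration above states (that the module exposes a main() callable), the same entry point could also be invoked programmatically:

# Equivalent to running the `lightrag-ollama` console script after installation
from lightrag.api.lightrag_ollama import main

if __name__ == "__main__":
    main()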