Complete the implementation of the Ollama-compatible API
@@ -2,14 +2,11 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import StreamingResponse  # used by the streaming branch of /api/chat
from pydantic import BaseModel
import logging
import argparse
from typing import List, Dict, Any, Optional
from lightrag import LightRAG, QueryParam
# from lightrag.llm import lollms_model_complete, lollms_embed
# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
from lightrag.llm import openai_complete_if_cache, ollama_embedding
# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding

from lightrag.utils import EmbeddingFunc
from enum import Enum
from pathlib import Path
import shutil
@@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
load_dotenv()

# Constants for model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
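# With the constants above, LIGHTRAG_MODEL evaluates to "lightrag:latest" -- the
# model name reported back to Ollama-compatible clients in /api/chat responses.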
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
@@ -219,21 +223,43 @@ class DocumentManager:
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"  # named global_ because "global" is a Python reserved keyword; the enum value is still the string "global"
    hybrid = "hybrid"
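# Since SearchMode subclasses str, SearchMode.global_ == "global" holds and
# SearchMode("global") resolves to this member, so request payloads can keep
# using the plain mode names.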

# Ollama API compatible models
class OllamaMessage(BaseModel):
    role: str
    content: str


class OllamaChatRequest(BaseModel):
    model: str = LIGHTRAG_MODEL
    messages: List[OllamaMessage]
    stream: bool = False
    options: Optional[Dict[str, Any]] = None


class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool


class OllamaVersionResponse(BaseModel):
    version: str
    build: str = "default"

class OllamaTagResponse(BaseModel):
    models: List[Dict[str, Any]]  # entries mix str and int fields (e.g. "size"), so values are typed Any
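# Illustrative /api/chat payloads these models are meant to carry (shapes follow
# Ollama's chat API; the content values are just examples):
#   request:  {"model": "lightrag:latest", "stream": false,
#              "messages": [{"role": "user", "content": "/local what is LightRAG?"}]}
#   response: {"model": "lightrag:latest", "created_at": "...",
#              "message": {"role": "assistant", "content": "..."}, "done": true}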

# Original LightRAG models
class QueryRequest(BaseModel):
    query: str
    mode: SearchMode = SearchMode.hybrid
    stream: bool = False
    only_need_context: bool = False


class QueryResponse(BaseModel):
    response: str


class InsertTextRequest(BaseModel):
    text: str
    description: Optional[str] = None
@@ -555,6 +581,101 @@ def create_app(args):
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    # Ollama compatible API endpoints
    @app.get("/api/version")
    async def get_version():
        """Get Ollama version information"""
        return OllamaVersionResponse(version="0.1.0")
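    # e.g. GET /api/version -> {"version": "0.1.0", "build": "default"}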

    @app.get("/api/tags")
    async def get_tags():
        """Get available models"""
        return OllamaTagResponse(
            models=[
                {
                    "name": LIGHTRAG_NAME,
                    "tag": LIGHTRAG_TAG,
                    "size": 0,
                    "digest": LIGHTRAG_DIGEST,
                    "modified_at": LIGHTRAG_CREATED_AT,
                }
            ]
        )
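    # e.g. GET /api/tags -> {"models": [{"name": "lightrag", "tag": "latest",
    #                                    "size": 0, "digest": "sha256:lightrag", ...}]}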

    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
        """Parse query prefix to determine search mode

        Returns tuple of (cleaned_query, search_mode)
        """
        mode_map = {
            "/local ": SearchMode.local,
            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
            "/naive ": SearchMode.naive,
            "/hybrid ": SearchMode.hybrid,
        }

        for prefix, mode in mode_map.items():
            if query.startswith(prefix):
                return query[len(prefix):], mode

        return query, SearchMode.hybrid
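    # Illustrative behaviour of the helper above:
    #   parse_query_mode("/global tell me about X") -> ("tell me about X", SearchMode.global_)
    #   parse_query_mode("no prefix here")          -> ("no prefix here", SearchMode.hybrid)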

    @app.post("/api/chat")
    async def chat(request: OllamaChatRequest):
        """Handle chat completion requests"""
        try:
            # Convert chat format to query
            query = request.messages[-1].content if request.messages else ""

            # Parse query mode and clean query
            cleaned_query, mode = parse_query_mode(query)

            # Call RAG with determined mode
            response = await rag.aquery(
                cleaned_query,
                param=QueryParam(mode=mode, stream=request.stream),
            )
            if request.stream:
                async def stream_generator():
                    result = ""
                    async for chunk in response:
                        result += chunk
                        yield OllamaChatResponse(
                            model=LIGHTRAG_MODEL,
                            created_at=LIGHTRAG_CREATED_AT,
                            message=OllamaMessage(role="assistant", content=chunk),
                            done=False,
                        ).model_dump_json() + "\n"  # .json() on pydantic v1
                    # Send final message
                    yield OllamaChatResponse(
                        model=LIGHTRAG_MODEL,
                        created_at=LIGHTRAG_CREATED_AT,
                        message=OllamaMessage(role="assistant", content=result),
                        done=True,
                    ).model_dump_json() + "\n"

                # Wrap the generator so FastAPI actually streams it, one JSON
                # object per line (Ollama-style NDJSON framing).
                return StreamingResponse(
                    stream_generator(), media_type="application/x-ndjson"
                )
            else:
                return OllamaChatResponse(
                    model=LIGHTRAG_MODEL,
                    created_at=LIGHTRAG_CREATED_AT,
                    message=OllamaMessage(role="assistant", content=response),
                    done=True,
                )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
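    # Illustrative call against the endpoint above (base URL depends on how the
    # app returned by create_app() is served; values are examples only):
    #   POST /api/chat  {"model": "lightrag:latest", "stream": false,
    #                    "messages": [{"role": "user", "content": "/hybrid summarize the indexed docs"}]}
    #   -> one OllamaChatResponse JSON object; with "stream": true, one NDJSON line per chunk.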

    @app.get("/health", dependencies=[Depends(optional_api_key)])
    async def get_status():
        """Get current system status"""