Complete the code for the Ollama-compatible interface
@@ -2,14 +2,11 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
 from pydantic import BaseModel
 import logging
 import argparse
+from typing import List, Dict, Any, Optional
 from lightrag import LightRAG, QueryParam
-# from lightrag.llm import lollms_model_complete, lollms_embed
-# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
 from lightrag.llm import openai_complete_if_cache, ollama_embedding
-# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 
 from lightrag.utils import EmbeddingFunc
-from typing import Optional, List
 from enum import Enum
 from pathlib import Path
 import shutil
@@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN
 from dotenv import load_dotenv
 load_dotenv()
 
+# Constants for model information
+LIGHTRAG_NAME = "lightrag"
+LIGHTRAG_TAG = "latest"
+LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
+LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
+LIGHTRAG_DIGEST = "sha256:lightrag"
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -219,21 +223,43 @@ class DocumentManager:
 class SearchMode(str, Enum):
     naive = "naive"
     local = "local"
-    global_ = "global"
+    global_ = "global"  # global_ is used because "global" is a Python reserved keyword; the enum value is still the string "global"
     hybrid = "hybrid"
 
-
+# Ollama API compatible models
+class OllamaMessage(BaseModel):
+    role: str
+    content: str
+
+class OllamaChatRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    messages: List[OllamaMessage]
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+class OllamaChatResponse(BaseModel):
+    model: str
+    created_at: str
+    message: OllamaMessage
+    done: bool
+
+class OllamaVersionResponse(BaseModel):
+    version: str
+    build: str = "default"
+
+class OllamaTagResponse(BaseModel):
+    models: List[Dict[str, str]]
+
+# Original LightRAG models
 class QueryRequest(BaseModel):
     query: str
     mode: SearchMode = SearchMode.hybrid
     stream: bool = False
     only_need_context: bool = False
 
-
 class QueryResponse(BaseModel):
     response: str
 
-
 class InsertTextRequest(BaseModel):
     text: str
     description: Optional[str] = None
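For reference, a minimal sketch of the JSON shape these request models accept. The values are illustrative; it assumes the models are importable from lightrag.api.lightrag_ollama (as the setup.py entry point below suggests) and Pydantic v2's model_dump(); on Pydantic v1 use .dict() instead.

# Illustrative only: building an Ollama-style chat payload from the models above.
from lightrag.api.lightrag_ollama import OllamaChatRequest, OllamaMessage

request = OllamaChatRequest(
    messages=[OllamaMessage(role="user", content="/local What is LightRAG?")]
)
print(request.model_dump())  # use request.dict() on Pydantic v1
# -> {'model': 'lightrag:latest',
#     'messages': [{'role': 'user', 'content': '/local What is LightRAG?'}],
#     'stream': False, 'options': None}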
@@ -555,6 +581,101 @@ def create_app(args):
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
 
+    # Ollama compatible API endpoints
+    @app.get("/api/version")
+    async def get_version():
+        """Get Ollama version information"""
+        return OllamaVersionResponse(
+            version="0.1.0"
+        )
+
+    @app.get("/api/tags")
+    async def get_tags():
+        """Get available models"""
+        return OllamaTagResponse(
+            models=[{
+                "name": LIGHTRAG_NAME,
+                "tag": LIGHTRAG_TAG,
+                "size": 0,
+                "digest": LIGHTRAG_DIGEST,
+                "modified_at": LIGHTRAG_CREATED_AT
+            }]
+        )
+
+    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
+        """Parse query prefix to determine search mode
+        Returns tuple of (cleaned_query, search_mode)
+        """
+        mode_map = {
+            "/local ": SearchMode.local,
+            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
+            "/naive ": SearchMode.naive,
+            "/hybrid ": SearchMode.hybrid
+        }
+
+        for prefix, mode in mode_map.items():
+            if query.startswith(prefix):
+                return query[len(prefix):], mode
+
+        return query, SearchMode.hybrid
+
+    @app.post("/api/chat")
+    async def chat(request: OllamaChatRequest):
+        """Handle chat completion requests"""
+        try:
+            # Convert chat format to query
+            query = request.messages[-1].content if request.messages else ""
+
+            # Parse query mode and clean query
+            cleaned_query, mode = parse_query_mode(query)
+
+            # Call RAG with determined mode
+            response = await rag.aquery(
+                cleaned_query,
+                param=QueryParam(
+                    mode=mode,
+                    stream=request.stream
+                )
+            )
+
+            if request.stream:
+                async def stream_generator():
+                    result = ""
+                    async for chunk in response:
+                        result += chunk
+                        yield OllamaChatResponse(
+                            model=LIGHTRAG_MODEL,
+                            created_at=LIGHTRAG_CREATED_AT,
+                            message=OllamaMessage(
+                                role="assistant",
+                                content=chunk
+                            ),
+                            done=False
+                        )
+                    # Send final message
+                    yield OllamaChatResponse(
+                        model=LIGHTRAG_MODEL,
+                        created_at=LIGHTRAG_CREATED_AT,
+                        message=OllamaMessage(
+                            role="assistant",
+                            content=result
+                        ),
+                        done=True
+                    )
+                return stream_generator()
+            else:
+                return OllamaChatResponse(
+                    model=LIGHTRAG_MODEL,
+                    created_at=LIGHTRAG_CREATED_AT,
+                    message=OllamaMessage(
+                        role="assistant",
+                        content=response
+                    ),
+                    done=True
+                )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.get("/health", dependencies=[Depends(optional_api_key)])
     async def get_status():
         """Get current system status"""
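For context, a hedged client-side sketch of how the new endpoints are meant to be called. The host and port are assumptions for illustration, not values taken from this commit; the "/local " prefix is stripped by parse_query_mode and selects SearchMode.local, while queries without a prefix fall back to hybrid mode.

# Hypothetical usage sketch: the base URL is an assumption, not part of this diff.
import requests

BASE = "http://localhost:9621"

payload = {
    "model": "lightrag:latest",
    "messages": [
        # "/local " is stripped by parse_query_mode and maps to SearchMode.local
        {"role": "user", "content": "/local How does the indexing pipeline work?"}
    ],
    "stream": False,
}

resp = requests.post(f"{BASE}/api/chat", json=payload, timeout=120)
print(resp.json()["message"]["content"])

# Ollama-style metadata endpoints added in this commit:
print(requests.get(f"{BASE}/api/version").json())  # {"version": "0.1.0", "build": "default"}
print(requests.get(f"{BASE}/api/tags").json())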
setup.py (+1)
@@ -101,6 +101,7 @@ setuptools.setup(
     entry_points={
         "console_scripts": [
             "lightrag-server=lightrag.api.lightrag_server:main [api]",
+            "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
         ],
     },
 )
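The added console_scripts entry makes lightrag-ollama a thin alias for the module's main(); a rough equivalent, assuming the package is installed with the api extra:

# Rough equivalent of the `lightrag-ollama` console script declared above:
# the entry point resolves to main() in lightrag.api.lightrag_ollama.
from lightrag.api.lightrag_ollama import main

if __name__ == "__main__":
    main()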