Fix stream handling logic; fix the /api/tags error
@@ -26,7 +26,7 @@ load_dotenv()
 # Constants for model information
 LIGHTRAG_NAME = "lightrag"
 LIGHTRAG_TAG = "latest"
-LIGHTRAG_MODEL = "{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
+LIGHTRAG_MODEL = "lightrag:latest"
 LIGHTRAG_SIZE = 7365960935
 LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
 LIGHTRAG_DIGEST = "sha256:lightrag"
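The first hunk fixes a string bug: without an `f` prefix, the braces in `"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"` are never interpolated, so the constant held that literal text instead of a model name. The commit substitutes the resolved literal; an f-string would have worked as well. A quick check:

```python
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"

broken = "{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"  # no f-prefix: braces are kept verbatim
fixed = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"  # equivalent to the new literal

assert broken == "{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
assert fixed == "lightrag:latest"
```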
@@ -247,8 +247,25 @@ class OllamaChatResponse(BaseModel):
 class OllamaVersionResponse(BaseModel):
     version: str
 
+class OllamaModelDetails(BaseModel):
+    parent_model: str
+    format: str
+    family: str
+    families: List[str]
+    parameter_size: str
+    quantization_level: str
+
+class OllamaModel(BaseModel):
+    name: str
+    model: str
+    tag: str
+    size: int
+    digest: str
+    modified_at: str
+    details: OllamaModelDetails
+
 class OllamaTagResponse(BaseModel):
-    models: List[Dict[str, str]]
+    models: List[OllamaModel]
 
 # Original LightRAG models
 class QueryRequest(BaseModel):
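These typed `OllamaModel` and `OllamaModelDetails` schemas replace the untyped `List[Dict[str, str]]`, which is what broke the /api/tags response shape for Ollama clients. The endpoint body itself is not part of this diff; a minimal sketch of how it might assemble the response from the constants above (the handler name and the `details` values such as `"gguf"` are illustrative assumptions):

```python
# Hypothetical sketch -- the real /api/tags handler is not shown in this diff.
@app.get("/api/tags")
async def get_tags() -> OllamaTagResponse:
    return OllamaTagResponse(
        models=[
            OllamaModel(
                name=LIGHTRAG_MODEL,
                model=LIGHTRAG_MODEL,
                tag=LIGHTRAG_TAG,
                size=LIGHTRAG_SIZE,
                digest=LIGHTRAG_DIGEST,
                modified_at=LIGHTRAG_CREATED_AT,
                details=OllamaModelDetails(
                    parent_model="",          # assumption: no parent model
                    format="gguf",            # illustrative value
                    family=LIGHTRAG_NAME,
                    families=[LIGHTRAG_NAME],
                    parameter_size="13B",     # illustrative value
                    quantization_level="Q4_0",  # illustrative value
                ),
            )
        ]
    )
```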
@@ -632,26 +649,46 @@ def create_app(args):
     async def chat(request: OllamaChatRequest):
         """Handle chat completion requests"""
         try:
-            # Convert chat format to query
-            query = request.messages[-1].content if request.messages else ""
+            # Get all message contents
+            messages = request.messages
+            if not messages:
+                raise HTTPException(status_code=400, detail="No messages provided")
 
-            # Parse query mode and clean query
+            # Take the last message as the query
+            query = messages[-1].content
+
+            # Parse the query mode
             cleaned_query, mode = parse_query_mode(query)
 
-            # Call RAG with determined mode
-            response = await rag.aquery(
-                cleaned_query,
-                param=QueryParam(
+            # Build the system prompt (if there are history messages)
+            system_prompt = None
+            history_messages = []
+            if len(messages) > 1:
+                # If the first message is a system message, extract it as system_prompt
+                if messages[0].role == "system":
+                    system_prompt = messages[0].content
+                    messages = messages[1:]
+
+                # Collect history messages (all but the last one)
+                history_messages = [(msg.role, msg.content) for msg in messages[:-1]]
+
+            # Run the RAG query
+            kwargs = {
+                "param": QueryParam(
                     mode=mode,
-                    stream=request.stream
+                    stream=request.stream,
                 )
-            )
+            }
+            if system_prompt is not None:
+                kwargs["system_prompt"] = system_prompt
+            if history_messages:
+                kwargs["history_messages"] = history_messages
+
+            response = await rag.aquery(cleaned_query, **kwargs)
 
             if request.stream:
                 async def stream_generator():
                     result = ""
                     async for chunk in response:
                         result += chunk
                         yield OllamaChatResponse(
                             model=LIGHTRAG_MODEL,
                             created_at=LIGHTRAG_CREATED_AT,
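The rewritten handler relies on `parse_query_mode`, which is outside this diff. A plausible sketch, assuming a leading `/mode` marker on the query (the exact marker set and the default mode are assumptions):

```python
# Plausible sketch -- the real parse_query_mode is not shown in this diff.
def parse_query_mode(query: str) -> tuple[str, str]:
    """Strip a leading '/mode ' marker and return (cleaned_query, mode)."""
    for mode in ("local", "global", "hybrid", "naive"):
        prefix = f"/{mode} "
        if query.startswith(prefix):
            return query[len(prefix):].lstrip(), mode
    return query, "hybrid"  # assumed default when no marker is present
```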
@@ -661,13 +698,13 @@ def create_app(args):
                             ),
                             done=False
                         )
-                    # Send final message
+                    # Send an empty completion message
                     yield OllamaChatResponse(
                         model=LIGHTRAG_MODEL,
                         created_at=LIGHTRAG_CREATED_AT,
                         message=OllamaMessage(
                             role="assistant",
                             content=""
                         ),
                         done=True
                     )
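The streaming fix sits in this final hunk: the terminating chunk used to resend the whole accumulated `result`, duplicating the answer on the client side; it now sends `done: true` with empty content, as Ollama clients expect. A hedged client-side check (the host, port, `/hybrid` prefix, and NDJSON framing of the response are assumptions based on the Ollama chat API):

```python
import json
import requests

# Illustrative client -- URL and mode prefix are assumptions.
resp = requests.post(
    "http://localhost:9621/api/chat",
    json={
        "model": "lightrag:latest",
        "stream": True,
        "messages": [{"role": "user", "content": "/hybrid What is LightRAG?"}],
    },
    stream=True,
)
answer = ""
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    if chunk["done"]:
        # After this commit the final chunk carries no content,
        # so the answer is exactly what was streamed incrementally.
        assert chunk["message"]["content"] == ""
        break
    answer += chunk["message"]["content"]
print(answer)
```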