draft implementation of /api/generate endpoint
@@ -475,6 +475,25 @@ class OllamaChatResponse(BaseModel):
     message: OllamaMessage
     done: bool
 
 
+class OllamaGenerateRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    prompt: str
+    system: Optional[str] = None
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+
+class OllamaGenerateResponse(BaseModel):
+    model: str
+    created_at: str
+    response: str
+    done: bool
+    context: Optional[List[int]]
+    total_duration: Optional[int]
+    load_duration: Optional[int]
+    prompt_eval_count: Optional[int]
+    prompt_eval_duration: Optional[int]
+    eval_count: Optional[int]
+    eval_duration: Optional[int]
+
+
 class OllamaVersionResponse(BaseModel):
     version: str
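For reference, a request body accepted by OllamaGenerateRequest and the matching non-streaming response shape look roughly like the sketch below. Only the field names come from the models above; the concrete values (model tag, timestamp, counts) are illustrative assumptions, not taken from this commit.

import json

# Illustrative /api/generate request body; "model" may be omitted by clients
# since it defaults to LIGHTRAG_MODEL on the server (the tag here is assumed).
generate_request = {
    "model": "lightrag:latest",
    "prompt": "What does this document say about LightRAG?",
    "system": None,
    "stream": False,
    "options": None,
}

# Non-streaming response shape mirroring OllamaGenerateResponse; the
# *_duration fields are nanoseconds, since the handler uses time.time_ns().
generate_response = {
    "model": "lightrag:latest",
    "created_at": "2025-01-01T00:00:00Z",  # assumed placeholder timestamp
    "response": "...generated answer...",
    "done": True,
    "total_duration": 0,
    "load_duration": 0,
    "prompt_eval_count": 0,
    "prompt_eval_duration": 0,
    "eval_count": 0,
    "eval_duration": 0,
}

print(json.dumps(generate_request, indent=2))
print(json.dumps(generate_response, indent=2))

Note that the handler's non-streaming return dict omits the optional context field, so it is left out of the sketch as well.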
@@ -1237,6 +1256,160 @@ def create_app(args):
         return query, SearchMode.hybrid
 
 
+    @app.post("/api/generate")
+    async def generate(raw_request: Request, request: OllamaGenerateRequest):
+        """Handle generate completion requests"""
+        try:
+            # Get the query text
+            query = request.prompt
+
+            # Parse the query mode from the prompt
+            cleaned_query, mode = parse_query_mode(query)
+
+            # Start timing (durations below are in nanoseconds, via time.time_ns())
+            start_time = time.time_ns()
+
+            # Estimate the number of input tokens
+            prompt_tokens = estimate_tokens(cleaned_query)
+
+            # Run the query through RAG
+            query_param = QueryParam(
+                mode=mode,
+                stream=request.stream,
+                only_need_context=False
+            )
+
+            # If there is a system prompt, update rag's llm_model_kwargs
+            if request.system:
+                rag.llm_model_kwargs["system_prompt"] = request.system
+
+            if request.stream:
+                from fastapi.responses import StreamingResponse
+
+                response = await rag.aquery(
+                    cleaned_query,
+                    param=query_param
+                )
+
+                async def stream_generator():
+                    try:
+                        first_chunk_time = None
+                        last_chunk_time = None
+                        total_response = ""
+
+                        # Handle the response
+                        if isinstance(response, str):
+                            # If it is a plain string, send it in two parts:
+                            # the full content, then a final statistics chunk
+                            first_chunk_time = time.time_ns()
+                            last_chunk_time = first_chunk_time
+                            total_response = response
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "response": response,
+                                "done": False
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                        else:
+                            async for chunk in response:
+                                if chunk:
+                                    if first_chunk_time is None:
+                                        first_chunk_time = time.time_ns()
+
+                                    last_chunk_time = time.time_ns()
+
+                                    total_response += chunk
+                                    data = {
+                                        "model": LIGHTRAG_MODEL,
+                                        "created_at": LIGHTRAG_CREATED_AT,
+                                        "response": chunk,
+                                        "done": False
+                                    }
+                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                            return
+
+                    except Exception as e:
+                        logging.error(f"Error in stream_generator: {str(e)}")
+                        raise
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="application/x-ndjson",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "application/x-ndjson",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type",
+                    },
+                )
+            else:
+                first_chunk_time = time.time_ns()
+                response_text = await rag.aquery(cleaned_query, param=query_param)
+                last_chunk_time = time.time_ns()
+
+                if not response_text:
+                    response_text = "No response generated"
+
+                completion_tokens = estimate_tokens(str(response_text))
+                total_time = last_chunk_time - start_time
+                prompt_eval_time = first_chunk_time - start_time
+                eval_time = last_chunk_time - first_chunk_time
+
+                return {
+                    "model": LIGHTRAG_MODEL,
+                    "created_at": LIGHTRAG_CREATED_AT,
+                    "response": str(response_text),
+                    "done": True,
+                    "total_duration": total_time,
+                    "load_duration": 0,
+                    "prompt_eval_count": prompt_tokens,
+                    "prompt_eval_duration": prompt_eval_time,
+                    "eval_count": completion_tokens,
+                    "eval_duration": eval_time
+                }
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
@app.post("/api/chat")
|
@app.post("/api/chat")
|
||||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
async def chat(raw_request: Request, request: OllamaChatRequest):
|
||||||
"""Handle chat completion requests"""
|
"""Handle chat completion requests"""
|
||||||
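For a quick manual check of the new endpoint, a small client along these lines should work. It is a sketch under assumptions: the base URL and port are placeholders for wherever the LightRAG API server is running, the requests package is assumed to be installed, and the prompts are arbitrary.

import json
import requests

BASE_URL = "http://localhost:9621"  # assumed default; adjust to your server

# Non-streaming: a single JSON object with the answer plus timing statistics.
resp = requests.post(
    f"{BASE_URL}/api/generate",
    json={"prompt": "What is LightRAG?", "stream": False},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])

# Streaming: application/x-ndjson, one JSON object per line; content chunks
# carry "done": false, and a final statistics object carries "done": true.
with requests.post(
    f"{BASE_URL}/api/generate",
    json={"prompt": "What is LightRAG?", "stream": True},
    stream=True,
    timeout=120,
) as streamed:
    for line in streamed.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if chunk.get("done"):
            break
        print(chunk["response"], end="", flush=True)

The streaming loop simply consumes what stream_generator() emits line by line, so it exercises both the content chunks and the final statistics record.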