diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 16f3ae49..d417d732 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -475,6 +475,25 @@ class OllamaChatResponse(BaseModel):
     message: OllamaMessage
     done: bool
 
+class OllamaGenerateRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    prompt: str
+    system: Optional[str] = None
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+class OllamaGenerateResponse(BaseModel):
+    model: str
+    created_at: str
+    response: str
+    done: bool
+    context: Optional[List[int]]
+    total_duration: Optional[int]
+    load_duration: Optional[int]
+    prompt_eval_count: Optional[int]
+    prompt_eval_duration: Optional[int]
+    eval_count: Optional[int]
+    eval_duration: Optional[int]
 
 class OllamaVersionResponse(BaseModel):
     version: str
@@ -1237,6 +1256,160 @@ def create_app(args):
         return query, SearchMode.hybrid
 
+    @app.post("/api/generate")
+    async def generate(raw_request: Request, request: OllamaGenerateRequest):
+        """Handle generate completion requests"""
+        try:
+            # Get the query content
+            query = request.prompt
+
+            # Parse the query mode
+            cleaned_query, mode = parse_query_mode(query)
+
+            # Start timing
+            start_time = time.time_ns()
+
+            # Estimate the number of input tokens
+            prompt_tokens = estimate_tokens(cleaned_query)
+
+            # Build the parameters for the RAG query
+            query_param = QueryParam(
+                mode=mode,
+                stream=request.stream,
+                only_need_context=False
+            )
+
+            # If a system prompt is provided, update rag's llm_model_kwargs
+            if request.system:
+                rag.llm_model_kwargs["system_prompt"] = request.system
+
+            if request.stream:
+                from fastapi.responses import StreamingResponse
+
+                response = await rag.aquery(
+                    cleaned_query,
+                    param=query_param
+                )
+
+                async def stream_generator():
+                    try:
+                        first_chunk_time = None
+                        last_chunk_time = None
+                        total_response = ""
+
+                        # Handle the response
+                        if isinstance(response, str):
+                            # If it is a plain string, send it in two parts
+                            first_chunk_time = time.time_ns()
+                            last_chunk_time = first_chunk_time
+                            total_response = response
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "response": response,
+                                "done": False
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                        else:
+                            async for chunk in response:
+                                if chunk:
+                                    if first_chunk_time is None:
+                                        first_chunk_time = time.time_ns()
+
+                                    last_chunk_time = time.time_ns()
+
+                                    total_response += chunk
+                                    data = {
+                                        "model": LIGHTRAG_MODEL,
+                                        "created_at": LIGHTRAG_CREATED_AT,
+                                        "response": chunk,
+                                        "done": False
+                                    }
+                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                        return
+
+                    except Exception as e:
+                        logging.error(f"Error in stream_generator: {str(e)}")
+                        raise
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="application/x-ndjson",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "application/x-ndjson",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type",
+                    },
+                )
+            else:
+                first_chunk_time = time.time_ns()
+                response_text = await rag.aquery(cleaned_query, param=query_param)
+                last_chunk_time = time.time_ns()
+
+                if not response_text:
+                    response_text = "No response generated"
+
+                completion_tokens = estimate_tokens(str(response_text))
+                total_time = last_chunk_time - start_time
+                prompt_eval_time = first_chunk_time - start_time
+                eval_time = last_chunk_time - first_chunk_time
+
+                return {
+                    "model": LIGHTRAG_MODEL,
+                    "created_at": LIGHTRAG_CREATED_AT,
+                    "response": str(response_text),
+                    "done": True,
+                    "total_duration": total_time,
+                    "load_duration": 0,
+                    "prompt_eval_count": prompt_tokens,
+                    "prompt_eval_duration": prompt_eval_time,
+                    "eval_count": completion_tokens,
+                    "eval_duration": eval_time
+                }
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.post("/api/chat")
     async def chat(raw_request: Request, request: OllamaChatRequest):
         """Handle chat completion requests"""
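
For quick manual testing, here is a minimal client-side sketch of the new /api/generate endpoint. It is not part of the patch: the base URL (http://localhost:9621), the use of the requests library, and the "/hybrid " mode prefix are assumptions; the request and response fields follow the OllamaGenerateRequest/OllamaGenerateResponse models added above.

# Client-side sketch (not part of the diff). Assumes the patched LightRAG API
# server is running and reachable at http://localhost:9621; adjust the address
# to your deployment. The "/hybrid " prefix assumes parse_query_mode() strips a
# leading mode marker from the prompt, as the handler above does.
import json
import requests

BASE_URL = "http://localhost:9621"  # assumed server address

# Non-streaming: a single JSON object with "done": true is returned.
resp = requests.post(
    f"{BASE_URL}/api/generate",
    json={"prompt": "/hybrid What topics do the indexed documents cover?", "stream": False},
)
resp.raise_for_status()
print(resp.json()["response"])

# Streaming: NDJSON, one JSON object per line; the final line carries the
# timing/token statistics and "done": true.
with requests.post(
    f"{BASE_URL}/api/generate",
    json={"prompt": "Summarize the indexed documents", "stream": True},
    stream=True,
) as streamed:
    for line in streamed.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if chunk.get("done"):
            break
        print(chunk.get("response", ""), end="", flush=True)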