Fix linting, remove redundant comments, and clean up code for better readability

This commit is contained in:
yangdx
2025-01-24 23:50:47 +08:00
parent 11873625a3
commit f30a69e201
2 changed files with 76 additions and 87 deletions


@@ -476,6 +476,7 @@ class OllamaChatResponse(BaseModel):
message: OllamaMessage
done: bool
class OllamaGenerateRequest(BaseModel):
model: str = LIGHTRAG_MODEL
prompt: str
@@ -483,6 +484,7 @@ class OllamaGenerateRequest(BaseModel):
stream: bool = False
options: Optional[Dict[str, Any]] = None
class OllamaGenerateResponse(BaseModel):
model: str
created_at: str
@@ -490,12 +492,13 @@ class OllamaGenerateResponse(BaseModel):
done: bool
context: Optional[List[int]]
total_duration: Optional[int]
load_duration: Optional[int]
prompt_eval_count: Optional[int]
prompt_eval_duration: Optional[int]
eval_count: Optional[int]
eval_duration: Optional[int]
class OllamaVersionResponse(BaseModel):
version: str
@@ -1262,52 +1265,45 @@ def create_app(args):
"""Handle generate completion requests"""
try:
query = request.prompt
# Start timing
start_time = time.time_ns()
# Count input tokens
prompt_tokens = estimate_tokens(query)
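# prompt_tokens is only an estimate; it feeds the Ollama-style prompt_eval_count stat returned below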
# Query llm_model_func directly
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.llm_model_func(
query,
stream=True,
**rag.llm_model_kwargs
query, stream=True, **rag.llm_model_kwargs
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Process the response
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send it in two parts
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"response": response,
"done": False
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
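# time.time_ns() means these durations are in nanoseconds, the same unit Ollama uses for its *_duration fields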
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
@@ -1317,7 +1313,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
@@ -1325,23 +1321,23 @@ def create_app(args):
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
@@ -1351,15 +1347,15 @@ def create_app(args):
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
@@ -1375,20 +1371,18 @@ def create_app(args):
else:
first_chunk_time = time.time_ns()
response_text = await rag.llm_model_func(
query,
stream=False,
**rag.llm_model_kwargs
query, stream=False, **rag.llm_model_kwargs
)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
@@ -1399,7 +1393,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
@@ -1417,16 +1411,12 @@ def create_app(args):
# Get the last message as query
query = messages[-1].content
# Parse the query mode
# Check for query prefix
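# parse_query_mode is assumed to strip an optional mode prefix (e.g. "/local" or "/global") from the query and return the cleaned text plus the selected mode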
cleaned_query, mode = parse_query_mode(query)
# Start timing
start_time = time.time_ns()
# Count input tokens
prompt_tokens = estimate_tokens(cleaned_query)
# Call RAG to run the query
query_param = QueryParam(
mode=mode, stream=request.stream, only_need_context=False
)
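# only_need_context=False presumably asks for a generated answer rather than just the retrieved context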
@@ -1537,25 +1527,21 @@ def create_app(args):
)
else:
first_chunk_time = time.time_ns()
# Use a regex to check whether the query contains a specific marker string
logging.info(f"Cleaned query content: {cleaned_query}")
match_result = re.search(r'\n<chat_history>\nUSER:', cleaned_query, re.MULTILINE)
logging.info(f"Regex match result: {bool(match_result)}")
if match_result:
# Determine if the request is from Open WebUI's session title and session keyword generation task
match_result = re.search(
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
)
if match_result:
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await rag.llm_model_func(
cleaned_query,
stream=False,
**rag.llm_model_kwargs
cleaned_query, stream=False, **rag.llm_model_kwargs
)
else:
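# Regular chat messages go through the full RAG pipeline via rag.aquery instead of calling the LLM directly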
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text: