pre-commit run --all-files

yangdx
2025-01-17 14:20:55 +08:00
parent 48f70ff8b4
commit fa9765ecd9
2 changed files with 163 additions and 149 deletions


@@ -24,22 +24,25 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
load_dotenv()
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
Chinese characters: approximately 1.5 tokens per character
English characters: approximately 0.25 tokens per character
"""
# Use regex to match Chinese and non-Chinese characters separately
-chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
-non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
+chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
+non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
# Calculate estimated token count
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
return int(tokens)
# Constants for model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
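For reference, a quick sketch of what estimate_tokens above returns on mixed input, using the weights from its docstring (the sample string is illustrative):

# Hypothetical usage of estimate_tokens as defined above
text = "LightRAG 支持混合检索"  # 9 non-Chinese chars (incl. the space), 6 Chinese chars
# 6 * 1.5 + 9 * 0.25 = 9.0 + 2.25 = 11.25, truncated by int() to 11
print(estimate_tokens(text))  # 11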
@@ -48,6 +51,7 @@ LIGHTRAG_SIZE = 7365960935
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
@@ -61,6 +65,7 @@ async def llm_model_func(
**kwargs,
)
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": "http://m4.lan.znipower.com:11434",
@@ -245,27 +250,32 @@ class SearchMode(str, Enum):
hybrid = "hybrid"
mix = "mix"
# Ollama API compatible models
class OllamaMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel):
model: str = LIGHTRAG_MODEL
messages: List[OllamaMessage]
stream: bool = True # Default to streaming mode
options: Optional[Dict[str, Any]] = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaVersionResponse(BaseModel):
version: str
class OllamaModelDetails(BaseModel):
parent_model: str
format: str
@@ -274,6 +284,7 @@ class OllamaModelDetails(BaseModel):
parameter_size: str
quantization_level: str
class OllamaModel(BaseModel):
name: str
model: str
@@ -282,9 +293,11 @@ class OllamaModel(BaseModel):
modified_at: str
details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
models: List[OllamaModel]
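For orientation, a minimal sketch of constructing the chat request these Pydantic models describe (field values are illustrative):

# Hypothetical payload matching OllamaChatRequest above
req = OllamaChatRequest(
    model=LIGHTRAG_MODEL,
    messages=[OllamaMessage(role="user", content="/mix what is LightRAG?")],
)
assert req.stream is True  # streaming is the default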
# Original LightRAG models
class QueryRequest(BaseModel):
query: str
@@ -292,9 +305,11 @@ class QueryRequest(BaseModel):
stream: bool = False
only_need_context: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
@@ -395,7 +410,9 @@ def create_app(args):
embedding_dim=1024,
max_token_size=8192,
func=lambda texts: ollama_embedding(
-texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434"
+texts,
+embed_model="bge-m3:latest",
+host="http://m4.lan.znipower.com:11434",
),
),
)
@@ -493,7 +510,7 @@ def create_app(args):
# If response is a string (e.g. cache hit), return directly
if isinstance(response, str):
return QueryResponse(response=response)
# If it's an async generator, decide whether to stream based on stream parameter
if request.stream:
result = ""
@@ -546,8 +563,8 @@ def create_app(args):
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"X-Accel-Buffering": "no" # 禁用 Nginx 缓冲
}
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲
},
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -652,29 +669,29 @@ def create_app(args):
@app.get("/api/version")
async def get_version():
"""Get Ollama version information"""
-return OllamaVersionResponse(
-version="0.5.4"
-)
+return OllamaVersionResponse(version="0.5.4")
@app.get("/api/tags")
async def get_tags():
"""Get available models"""
return OllamaTagResponse(
-models=[{
-"name": LIGHTRAG_MODEL,
-"model": LIGHTRAG_MODEL,
-"size": LIGHTRAG_SIZE,
-"digest": LIGHTRAG_DIGEST,
-"modified_at": LIGHTRAG_CREATED_AT,
-"details": {
-"parent_model": "",
-"format": "gguf",
-"family": LIGHTRAG_NAME,
-"families": [LIGHTRAG_NAME],
-"parameter_size": "13B",
-"quantization_level": "Q4_0"
-}
-}]
+models=[
+{
+"name": LIGHTRAG_MODEL,
+"model": LIGHTRAG_MODEL,
+"size": LIGHTRAG_SIZE,
+"digest": LIGHTRAG_DIGEST,
+"modified_at": LIGHTRAG_CREATED_AT,
+"details": {
+"parent_model": "",
+"format": "gguf",
+"family": LIGHTRAG_NAME,
+"families": [LIGHTRAG_NAME],
+"parameter_size": "13B",
+"quantization_level": "Q4_0",
+},
+}
+]
)
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
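A quick way to exercise the two endpoints above from a client; the base URL and port are assumptions, so match them to however the server is actually started:

import requests  # hypothetical client-side check, not part of this file

BASE_URL = "http://localhost:9621"  # assumed port; use your server's
print(requests.get(f"{BASE_URL}/api/version").json())  # {"version": "0.5.4"}
tags = requests.get(f"{BASE_URL}/api/tags").json()
print(tags["models"][0]["name"])  # LIGHTRAG_MODEL, e.g. "lightrag:latest"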
@@ -686,15 +703,15 @@ def create_app(args):
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid,
"/mix ": SearchMode.mix
"/mix ": SearchMode.mix,
}
for prefix, mode in mode_map.items():
if query.startswith(prefix):
# After removing the prefix and leading spaces
-cleaned_query = query[len(prefix):].lstrip()
+cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode
return query, SearchMode.hybrid
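A minimal sketch of how the prefix parsing above behaves (query strings are illustrative):

# Hypothetical calls to parse_query_mode as defined above
assert parse_query_mode("/naive what is LightRAG?") == ("what is LightRAG?", SearchMode.naive)
assert parse_query_mode("/mix   padded query") == ("padded query", SearchMode.mix)
# Without a recognized prefix, the query falls through to the hybrid default
assert parse_query_mode("plain question") == ("plain question", SearchMode.hybrid)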
@app.post("/api/chat")
@@ -705,32 +722,29 @@ def create_app(args):
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query
query = messages[-1].content
# Parse the query mode
cleaned_query, mode = parse_query_mode(query)
# Start timing
start_time = time.time_ns()
# Count input tokens
prompt_tokens = estimate_tokens(cleaned_query)
# Run the RAG query
query_param = QueryParam(
-mode=mode,
-stream=request.stream,
-only_need_context=False
+mode=mode, stream=request.stream, only_need_context=False
)
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.aquery( # Need await to get async generator
-cleaned_query,
-param=query_param
+cleaned_query, param=query_param
)
async def stream_generator():
@@ -738,33 +752,37 @@ def create_app(args):
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
# First send: the query content
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"role": "assistant",
"content": response,
"images": None
"images": None,
},
"done": False
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
# Compute the metrics
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time # total time
-prompt_eval_time = first_chunk_time - start_time # time before the first response
-eval_time = last_chunk_time - first_chunk_time # time to generate the response
+prompt_eval_time = (
+first_chunk_time - start_time
+) # time before the first response
+eval_time = (
+last_chunk_time - first_chunk_time
+) # time to generate the response
# Second send: the statistics
data = {
"model": LIGHTRAG_MODEL,
@@ -775,7 +793,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": completion_tokens, # 输出token数
"eval_duration": eval_time # 生成响应的时间
"eval_duration": eval_time, # 生成响应的时间
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
@@ -785,10 +803,10 @@ def create_app(args):
# Record the time of the first chunk
if first_chunk_time is None:
first_chunk_time = time.time_ns()
# Update the time of the last chunk
last_chunk_time = time.time_ns()
# Accumulate the response content
total_response += chunk
data = {
@@ -797,18 +815,22 @@ def create_app(args):
"message": {
"role": "assistant",
"content": chunk,
"images": None
"images": None,
},
"done": False
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
# Compute the metrics
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time # total time
-prompt_eval_time = first_chunk_time - start_time # time before the first response
-eval_time = last_chunk_time - first_chunk_time # time to generate the response
+prompt_eval_time = (
+first_chunk_time - start_time
+) # time before the first response
+eval_time = (
+last_chunk_time - first_chunk_time
+) # time to generate the response
# Send the completion marker with the performance statistics
data = {
"model": LIGHTRAG_MODEL,
@@ -819,14 +841,14 @@ def create_app(args):
"prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": completion_tokens, # 输出token数
"eval_duration": eval_time # 生成响应的时间
"eval_duration": eval_time, # 生成响应的时间
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return # 确保生成器在发送完成标记后立即结束
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
@@ -836,28 +858,25 @@ def create_app(args):
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type"
}
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
# Non-streaming response
first_chunk_time = time.time_ns()
-response_text = await rag.aquery(
-cleaned_query,
-param=query_param
-)
+response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
# Ensure the response is not empty
if not response_text:
response_text = "No response generated"
# Compute the metrics
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time # total time
prompt_eval_time = first_chunk_time - start_time # time before the first response
eval_time = last_chunk_time - first_chunk_time # time to generate the response
# Build the response with the performance statistics
return {
"model": LIGHTRAG_MODEL,
@@ -865,7 +884,7 @@ def create_app(args):
"message": {
"role": "assistant",
"content": str(response_text), # 确保转换为字符串
"images": None
"images": None,
},
"done": True,
"total_duration": total_time, # 总时间
@@ -873,7 +892,7 @@ def create_app(args):
"prompt_eval_count": prompt_tokens, # 输入token数
"prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间
"eval_count": completion_tokens, # 输出token数
"eval_duration": eval_time # 生成响应的时间
"eval_duration": eval_time, # 生成响应的时间
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
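Since /api/chat streams newline-delimited JSON in Ollama's format, a hedged sketch of a client consuming that stream may help; the base URL, port, and model name are assumptions:

import json
import requests  # hypothetical client, not part of this file

BASE_URL = "http://localhost:9621"  # assumed; use the port the server listens on
payload = {
    "model": "lightrag:latest",  # assumed LIGHTRAG_MODEL value
    "messages": [{"role": "user", "content": "/naive what is LightRAG?"}],
    "stream": True,
}
with requests.post(f"{BASE_URL}/api/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if chunk.get("done"):
            # The final chunk carries the timing stats (durations are in nanoseconds)
            print("\n", chunk["eval_count"], "tokens in", chunk["eval_duration"], "ns")
        else:
            print(chunk["message"]["content"], end="", flush=True)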