From 0bfeb237e38d0de5789c00b91022f51e6820dacc Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 14 Jan 2025 23:04:41 +0800 Subject: [PATCH 01/42] Create yangdx branch and add test script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 3 +- examples/lightrag_yangdx.py | 70 +++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 examples/lightrag_yangdx.py diff --git a/README.md b/README.md index 90c3ec04..f1aeed99 100644 --- a/README.md +++ b/README.md @@ -695,7 +695,7 @@ Output the results in the following structure: ``` - ### Batch Eval +### Batch Eval To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
@@ -746,6 +746,7 @@ Output your evaluation in the following JSON format:
### Overall Performance Table + | | **Agriculture** | | **CS** | | **Legal** | | **Mix** | | |----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------| | | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | diff --git a/examples/lightrag_yangdx.py b/examples/lightrag_yangdx.py new file mode 100644 index 00000000..162900c4 --- /dev/null +++ b/examples/lightrag_yangdx.py @@ -0,0 +1,70 @@ +import asyncio +import os +import inspect +import logging +from lightrag import LightRAG, QueryParam +from lightrag.llm import ollama_model_complete, ollama_embedding +from lightrag.utils import EmbeddingFunc + +WORKING_DIR = "./dickens" + +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=ollama_model_complete, + llm_model_name="gemma2:2b", + llm_model_max_async=4, + llm_model_max_token_size=32768, + llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}}, + embedding_func=EmbeddingFunc( + embedding_dim=768, + max_token_size=8192, + func=lambda texts: ollama_embedding( + texts, embed_model="nomic-embed-text", host="http://localhost:11434" + ), + ), +) + +with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) + +# stream response +resp = rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid", stream=True), +) + + +async def print_stream(stream): + async for chunk in stream: + print(chunk, end="", flush=True) + + +if inspect.isasyncgen(resp): + asyncio.run(print_stream(resp)) +else: + print(resp) From 294b0359e89ae634bfd8acf726dda2a722bf0b99 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 00:55:48 +0800 Subject: [PATCH 02/42] =?UTF-8?q?=E4=BF=AE=E6=94=B9llm=E4=B8=BAdeepseek-ch?= =?UTF-8?q?at?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/lightrag_yangdx.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/examples/lightrag_yangdx.py b/examples/lightrag_yangdx.py index 162900c4..72f42cc4 100644 --- a/examples/lightrag_yangdx.py +++ b/examples/lightrag_yangdx.py @@ -2,34 +2,46 @@ import asyncio import os import inspect import logging +from dotenv import load_dotenv from lightrag import LightRAG, QueryParam -from lightrag.llm import ollama_model_complete, ollama_embedding +from lightrag.llm import openai_complete_if_cache, ollama_embedding from lightrag.utils import EmbeddingFunc -WORKING_DIR = "./dickens" +load_dotenv() + +WORKING_DIR = "./examples/input" logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + return await 
openai_complete_if_cache( + "deepseek-chat", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("DEEPSEEK_API_KEY"), + base_url=os.getenv("DEEPSEEK__ENDPOINT"), + **kwargs, + ) + if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) rag = LightRAG( working_dir=WORKING_DIR, - llm_model_func=ollama_model_complete, - llm_model_name="gemma2:2b", - llm_model_max_async=4, - llm_model_max_token_size=32768, - llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}}, + llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( embedding_dim=768, max_token_size=8192, func=lambda texts: ollama_embedding( - texts, embed_model="nomic-embed-text", host="http://localhost:11434" + texts, embed_model="nomic-embed-text", host="http://m4.lan.znipower.com:11434" ), ), ) -with open("./book.txt", "r", encoding="utf-8") as f: +with open("./input/book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search From b11c33d7a1b1987873885b817ccd0c6a6050bd69 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 01:25:49 +0800 Subject: [PATCH 03/42] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dembedding=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=BA=AC=E5=BA=A6=E6=95=B0=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/lightrag_yangdx.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/lightrag_yangdx.py b/examples/lightrag_yangdx.py index 72f42cc4..eeb651f8 100644 --- a/examples/lightrag_yangdx.py +++ b/examples/lightrag_yangdx.py @@ -9,7 +9,7 @@ from lightrag.utils import EmbeddingFunc load_dotenv() -WORKING_DIR = "./examples/input" +WORKING_DIR = "./examples/output" logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) @@ -22,7 +22,7 @@ async def llm_model_func( system_prompt=system_prompt, history_messages=history_messages, api_key=os.getenv("DEEPSEEK_API_KEY"), - base_url=os.getenv("DEEPSEEK__ENDPOINT"), + base_url=os.getenv("DEEPSEEK_ENDPOINT"), **kwargs, ) @@ -33,15 +33,15 @@ rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( - embedding_dim=768, + embedding_dim=1024, max_token_size=8192, func=lambda texts: ollama_embedding( - texts, embed_model="nomic-embed-text", host="http://m4.lan.znipower.com:11434" + texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434" ), ), ) -with open("./input/book.txt", "r", encoding="utf-8") as f: +with open("./examples/input/book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search From 33e211789e5c51943e3fc37303a6ab5b7af6f2ab Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 01:35:28 +0800 Subject: [PATCH 04/42] =?UTF-8?q?=E5=88=A0=E9=99=A4=20.DS=5FStore=20?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .DS_Store | Bin 8196 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 7489d923a9e7375a0aadb3d45e73e969d67f761d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHMTWl0n7(U;$(3#;jz3G5m*ica{3$&KdV$p0{3Y034-f7Wwc4wqxr!!?|b_)_> zjrt@S6K~N(LcBcqAc-1CG-@>AfvEApG)7G1#Y7+U!Ap#qVDLY4W(jTK!GstR<|OC* z^Pm5G=klNLpR;F{F@}zU*~FO67?UY;sV$=73W?i!o|mMkrko@Q&zQwr=4Yq;$*eQ4 zbzBj9Aof7)f!G7F2VxIg4<4X1nJAnagw}r-ck(9aMx@ 
z0FwL)5GH!2dq9XuMlv1AX(2->^(k%-h@gnB7!Yu>CwX&{=}1lsDc}qO&Jc}^=!SxD zbiyUW<_u{e<1+R@?1A|n;NerjOlGqj8^8Si?q$jHes&XC>4t; z?(ZA(ruqUSDEJ*lew1tdp66ss&qj98wx%*#O{eDuhV5GU5MY}YHQ7I8yQVkQ;T2rd z53zm1peVAE(Uu=M(%iHqm27TlIhslyX=-e3P9-Zv$A9Jm*x1y&aA_iC0!^j>zim7k zjomrV-J7?#eJJNSg$dWsXq7p`$rT(Uu-YA`&wj-6hXdgYle!pZ)Y8AV9GXw zv3-W0vs}}54~&ol-N1gxw(?3PPAL7ZF=5SUOP4LV>1JIYE@bV41v@yb)yuNzt4?33 zXs6`;CXIieXL6(&TJ$=7wJeYF`p&#XrKncd=(j7HA@}>V1kpu-%_T?WSp&(aQC` zdatZ~MHtO?i#Ac1wpTD&u1`0Joe14mpKcVILI=}wwt~?&+HHOC;N;24GCpffQ6_=9THfH1~g+6Hlqbw(S_~UiEix1 z5QZ^=Q5bj_)`S8ZY7{yo^`y7T(2sIEB;r6rbU9e1Y@$5x?R${EiFw z2NxwKRZBHeLRu}Y6VsMTkuf}*UQ&wmIRQ_RK_BeG$T)bV$VhD8vGe+fjB|^5EL2ui z*VL}$0kXL*eThbbcqOh~&f$=@v&bIOHr(gAe4~h2zTy@&F+jYw4$Y0UQUq!hoWb%8 zZ@nt1CbR{_e4*aDhSY_`dZFIYkVt3}1)ESeY)q(H8HJJ(Vw)RCMWnQ}s;X8_K_zTk z?toeTRL(K4E9s3{mw5_7nSsU4)FKs3BtBhP7CS6xI_tTX7fK zLW1rfdhS9G`Y}ip-A5$N!NLKI!A2ed3L#M+3yJz9p24$t4$tES9LEW~hS%{1-o!~F z?T2&dxN9~Y@zH!b7W?w9dqk Date: Wed, 15 Jan 2025 02:25:01 +0800 Subject: [PATCH 05/42] =?UTF-8?q?=E5=8E=BB=E6=8E=89=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E5=92=8C=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/lightrag_yangdx.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/lightrag_yangdx.py b/examples/lightrag_yangdx.py index eeb651f8..2691deac 100644 --- a/examples/lightrag_yangdx.py +++ b/examples/lightrag_yangdx.py @@ -1,6 +1,6 @@ -import asyncio import os -import inspect +# import asyncio +# import inspect import logging from dotenv import load_dotenv from lightrag import LightRAG, QueryParam @@ -64,19 +64,19 @@ print( rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) ) -# stream response -resp = rag.query( - "What are the top themes in this story?", - param=QueryParam(mode="hybrid", stream=True), -) +# # stream response +# resp = rag.query( +# "What are the top themes in this story?", +# param=QueryParam(mode="hybrid", stream=True), +# ) -async def print_stream(stream): - async for chunk in stream: - print(chunk, end="", flush=True) +# async def print_stream(stream): +# async for chunk in stream: +# print(chunk, end="", flush=True) -if inspect.isasyncgen(resp): - asyncio.run(print_stream(resp)) -else: - print(resp) +# if inspect.isasyncgen(resp): +# asyncio.run(print_stream(resp)) +# else: +# print(resp) From 1088e10fb245b76ab02bbca27849a1368533ce39 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 10:44:12 +0800 Subject: [PATCH 06/42] =?UTF-8?q?=E4=BF=AE=E6=94=B9lightrag=5Fserver?= =?UTF-8?q?=E7=9A=84LLM=E5=92=8CEmbedding=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_server.py | 69 +++++++++++++-------------------- lightrag/api/start.sh | 1 + 2 files changed, 27 insertions(+), 43 deletions(-) create mode 100755 lightrag/api/start.sh diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5bcb149c..42ae68f4 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -3,10 +3,10 @@ from pydantic import BaseModel import logging import argparse from lightrag import LightRAG, QueryParam -from lightrag.llm import lollms_model_complete, lollms_embed -from lightrag.llm import ollama_model_complete, ollama_embed -from lightrag.llm import openai_complete_if_cache, openai_embedding -from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding +# from 
lightrag.llm import lollms_model_complete, lollms_embed +# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding +from lightrag.llm import openai_complete_if_cache, ollama_embedding +# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding from lightrag.utils import EmbeddingFunc from typing import Optional, List @@ -23,13 +23,28 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN +from dotenv import load_dotenv +load_dotenv() + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + return await openai_complete_if_cache( + "deepseek-chat", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("DEEPSEEK_API_KEY"), + base_url=os.getenv("DEEPSEEK_ENDPOINT"), + **kwargs, + ) def get_default_host(binding_type: str) -> str: default_hosts = { - "ollama": "http://localhost:11434", + "ollama": "http://m4.lan.znipower.com:11434", "lollms": "http://localhost:9600", "azure_openai": "https://api.openai.com/v1", - "openai": "https://api.openai.com/v1", + "openai": os.getenv("DEEPSEEK_ENDPOINT"), } return default_hosts.get( binding_type, "http://localhost:11434" @@ -314,44 +329,12 @@ def create_app(args): # Initialize RAG rag = LightRAG( working_dir=args.working_dir, - llm_model_func=lollms_model_complete - if args.llm_binding == "lollms" - else ollama_model_complete - if args.llm_binding == "ollama" - else azure_openai_complete_if_cache - if args.llm_binding == "azure_openai" - else openai_complete_if_cache, - llm_model_name=args.llm_model, - llm_model_max_async=args.max_async, - llm_model_max_token_size=args.max_tokens, - llm_model_kwargs={ - "host": args.llm_binding_host, - "timeout": args.timeout, - "options": {"num_ctx": args.max_tokens}, - }, + llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( - embedding_dim=args.embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: lollms_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - ) - if args.llm_binding == "lollms" - else ollama_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - ) - if args.llm_binding == "ollama" - else azure_openai_embedding( - texts, - model=args.embedding_model, # no host is used for openai - ) - if args.llm_binding == "azure_openai" - else openai_embedding( - texts, - model=args.embedding_model, # no host is used for openai + embedding_dim=1024, + max_token_size=8192, + func=lambda texts: ollama_embedding( + texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434" ), ), ) diff --git a/lightrag/api/start.sh b/lightrag/api/start.sh new file mode 100755 index 00000000..3e96199e --- /dev/null +++ b/lightrag/api/start.sh @@ -0,0 +1 @@ +python lightrag_server.py --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024 From da915103637f1ce1bafbf642fe8f3412811876ac Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 13:14:09 +0800 Subject: [PATCH 07/42] =?UTF-8?q?=E4=BF=AE=E6=94=B9server=E5=90=AF?= =?UTF-8?q?=E5=8A=A8=E5=91=BD=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/start.sh | 1 - start-server.sh | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100755 lightrag/api/start.sh create mode 100755 start-server.sh diff --git a/lightrag/api/start.sh 
b/lightrag/api/start.sh deleted file mode 100755 index 3e96199e..00000000 --- a/lightrag/api/start.sh +++ /dev/null @@ -1 +0,0 @@ -python lightrag_server.py --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024 diff --git a/start-server.sh b/start-server.sh new file mode 100755 index 00000000..2f712098 --- /dev/null +++ b/start-server.sh @@ -0,0 +1 @@ +lightrag-server --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024 From 9bfba88600ec1e77abefe872ee12665784f6c6f2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 13:32:06 +0800 Subject: [PATCH 08/42] =?UTF-8?q?=E5=87=86=E5=A4=87=E5=A2=9E=E5=8A=A0ollam?= =?UTF-8?q?a=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama_server.py | 603 +++++++++++++++++++++++++ 1 file changed, 603 insertions(+) create mode 100644 lightrag/api/lightrag_ollama_server.py diff --git a/lightrag/api/lightrag_ollama_server.py b/lightrag/api/lightrag_ollama_server.py new file mode 100644 index 00000000..42ae68f4 --- /dev/null +++ b/lightrag/api/lightrag_ollama_server.py @@ -0,0 +1,603 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile, Form +from pydantic import BaseModel +import logging +import argparse +from lightrag import LightRAG, QueryParam +# from lightrag.llm import lollms_model_complete, lollms_embed +# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding +from lightrag.llm import openai_complete_if_cache, ollama_embedding +# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding + +from lightrag.utils import EmbeddingFunc +from typing import Optional, List +from enum import Enum +from pathlib import Path +import shutil +import aiofiles +from ascii_colors import trace_exception +import os + +from fastapi import Depends, Security +from fastapi.security import APIKeyHeader +from fastapi.middleware.cors import CORSMiddleware + +from starlette.status import HTTP_403_FORBIDDEN + +from dotenv import load_dotenv +load_dotenv() + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + return await openai_complete_if_cache( + "deepseek-chat", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("DEEPSEEK_API_KEY"), + base_url=os.getenv("DEEPSEEK_ENDPOINT"), + **kwargs, + ) + +def get_default_host(binding_type: str) -> str: + default_hosts = { + "ollama": "http://m4.lan.znipower.com:11434", + "lollms": "http://localhost:9600", + "azure_openai": "https://api.openai.com/v1", + "openai": os.getenv("DEEPSEEK_ENDPOINT"), + } + return default_hosts.get( + binding_type, "http://localhost:11434" + ) # fallback to ollama if unknown + + +def parse_args(): + parser = argparse.ArgumentParser( + description="LightRAG FastAPI Server with separate working and input directories" + ) + + # Start by the bindings + parser.add_argument( + "--llm-binding", + default="ollama", + help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)", + ) + parser.add_argument( + "--embedding-binding", + default="ollama", + help="Embedding binding to be used. 
Supported: lollms, ollama, openai (default: ollama)", + ) + + # Parse just these arguments first + temp_args, _ = parser.parse_known_args() + + # Add remaining arguments with dynamic defaults for hosts + # Server configuration + parser.add_argument( + "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=9621, help="Server port (default: 9621)" + ) + + # Directory configuration + parser.add_argument( + "--working-dir", + default="./rag_storage", + help="Working directory for RAG storage (default: ./rag_storage)", + ) + parser.add_argument( + "--input-dir", + default="./inputs", + help="Directory containing input documents (default: ./inputs)", + ) + + # LLM Model configuration + default_llm_host = get_default_host(temp_args.llm_binding) + parser.add_argument( + "--llm-binding-host", + default=default_llm_host, + help=f"llm server host URL (default: {default_llm_host})", + ) + + parser.add_argument( + "--llm-model", + default="mistral-nemo:latest", + help="LLM model name (default: mistral-nemo:latest)", + ) + + # Embedding model configuration + default_embedding_host = get_default_host(temp_args.embedding_binding) + parser.add_argument( + "--embedding-binding-host", + default=default_embedding_host, + help=f"embedding server host URL (default: {default_embedding_host})", + ) + + parser.add_argument( + "--embedding-model", + default="bge-m3:latest", + help="Embedding model name (default: bge-m3:latest)", + ) + + def timeout_type(value): + if value is None or value == "None": + return None + return int(value) + + parser.add_argument( + "--timeout", + default=None, + type=timeout_type, + help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", + ) + # RAG configuration + parser.add_argument( + "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" + ) + parser.add_argument( + "--max-tokens", + type=int, + default=32768, + help="Maximum token size (default: 32768)", + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=1024, + help="Embedding dimensions (default: 1024)", + ) + parser.add_argument( + "--max-embed-tokens", + type=int, + default=8192, + help="Maximum embedding token size (default: 8192)", + ) + + # Logging configuration + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level (default: INFO)", + ) + + parser.add_argument( + "--key", + type=str, + help="API key for authentication. 
This protects lightrag server against unauthorized access", + default=None, + ) + + # Optional https parameters + parser.add_argument( + "--ssl", action="store_true", help="Enable HTTPS (default: False)" + ) + parser.add_argument( + "--ssl-certfile", + default=None, + help="Path to SSL certificate file (required if --ssl is enabled)", + ) + parser.add_argument( + "--ssl-keyfile", + default=None, + help="Path to SSL private key file (required if --ssl is enabled)", + ) + return parser.parse_args() + + +class DocumentManager: + """Handles document operations and tracking""" + + def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): + self.input_dir = Path(input_dir) + self.supported_extensions = supported_extensions + self.indexed_files = set() + + # Create input directory if it doesn't exist + self.input_dir.mkdir(parents=True, exist_ok=True) + + def scan_directory(self) -> List[Path]: + """Scan input directory for new files""" + new_files = [] + for ext in self.supported_extensions: + for file_path in self.input_dir.rglob(f"*{ext}"): + if file_path not in self.indexed_files: + new_files.append(file_path) + return new_files + + def mark_as_indexed(self, file_path: Path): + """Mark a file as indexed""" + self.indexed_files.add(file_path) + + def is_supported_file(self, filename: str) -> bool: + """Check if file type is supported""" + return any(filename.lower().endswith(ext) for ext in self.supported_extensions) + + +# Pydantic models +class SearchMode(str, Enum): + naive = "naive" + local = "local" + global_ = "global" + hybrid = "hybrid" + + +class QueryRequest(BaseModel): + query: str + mode: SearchMode = SearchMode.hybrid + stream: bool = False + only_need_context: bool = False + + +class QueryResponse(BaseModel): + response: str + + +class InsertTextRequest(BaseModel): + text: str + description: Optional[str] = None + + +class InsertResponse(BaseModel): + status: str + message: str + document_count: int + + +def get_api_key_dependency(api_key: Optional[str]): + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="API Key required" + ) + if api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth + + +def create_app(args): + # Verify that bindings arer correctly setup + if args.llm_binding not in ["lollms", "ollama", "openai"]: + raise Exception("llm binding not supported") + + if args.embedding_binding not in ["lollms", "ollama", "openai"]: + raise Exception("embedding binding not supported") + + # Add SSL validation + if args.ssl: + if not args.ssl_certfile or not args.ssl_keyfile: + raise Exception( + "SSL certificate and key files must be provided when SSL is enabled" + ) + if not os.path.exists(args.ssl_certfile): + raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") + if not os.path.exists(args.ssl_keyfile): + raise Exception(f"SSL key file not found: {args.ssl_keyfile}") + + # Setup logging + logging.basicConfig( + format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) + ) + + # Check if API key is 
provided either through env var or args + api_key = os.getenv("LIGHTRAG_API_KEY") or args.key + + # Initialize FastAPI + app = FastAPI( + title="LightRAG API", + description="API for querying text using LightRAG with separate storage and input directories" + + "(With authentication)" + if api_key + else "", + version="1.0.1", + openapi_tags=[{"name": "api"}], + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Create the optional API key dependency + optional_api_key = get_api_key_dependency(api_key) + + # Create working directory if it doesn't exist + Path(args.working_dir).mkdir(parents=True, exist_ok=True) + + # Initialize document manager + doc_manager = DocumentManager(args.input_dir) + + # Initialize RAG + rag = LightRAG( + working_dir=args.working_dir, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=1024, + max_token_size=8192, + func=lambda texts: ollama_embedding( + texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434" + ), + ), + ) + + @app.on_event("startup") + async def startup_event(): + """Index all files in input directory during startup""" + try: + new_files = doc_manager.scan_directory() + for file_path in new_files: + try: + # Use async file reading + async with aiofiles.open(file_path, "r", encoding="utf-8") as f: + content = await f.read() + # Use the async version of insert directly + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + logging.info(f"Indexed file: {file_path}") + except Exception as e: + trace_exception(e) + logging.error(f"Error indexing file {file_path}: {str(e)}") + + logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") + + except Exception as e: + logging.error(f"Error during startup indexing: {str(e)}") + + @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) + async def scan_for_new_documents(): + """Manually trigger scanning for new documents""" + try: + new_files = doc_manager.scan_directory() + indexed_count = 0 + + for file_path in new_files: + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + indexed_count += 1 + except Exception as e: + logging.error(f"Error indexing file {file_path}: {str(e)}") + + return { + "status": "success", + "indexed_count": indexed_count, + "total_documents": len(doc_manager.indexed_files), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) + async def upload_to_input_dir(file: UploadFile = File(...)): + """Upload a file to the input directory""" + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. 
Supported types: {doc_manager.supported_extensions}", + ) + + file_path = doc_manager.input_dir / file.filename + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + # Immediately index the uploaded file + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + + return { + "status": "success", + "message": f"File uploaded and indexed: {file.filename}", + "total_documents": len(doc_manager.indexed_files), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) + async def query_text(request: QueryRequest): + try: + response = await rag.aquery( + request.query, + param=QueryParam( + mode=request.mode, + stream=request.stream, + only_need_context=request.only_need_context, + ), + ) + + if request.stream: + result = "" + async for chunk in response: + result += chunk + return QueryResponse(response=result) + else: + return QueryResponse(response=response) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) + async def query_text_stream(request: QueryRequest): + try: + response = rag.query( + request.query, + param=QueryParam( + mode=request.mode, + stream=True, + only_need_context=request.only_need_context, + ), + ) + + async def stream_generator(): + async for chunk in response: + yield chunk + + return stream_generator() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/text", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_text(request: InsertTextRequest): + try: + await rag.ainsert(request.text) + return InsertResponse( + status="success", + message="Text successfully inserted", + document_count=1, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/file", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_file(file: UploadFile = File(...), description: str = Form(None)): + try: + content = await file.read() + + if file.filename.endswith((".txt", ".md")): + text = content.decode("utf-8") + await rag.ainsert(text) + else: + raise HTTPException( + status_code=400, + detail="Unsupported file type. Only .txt and .md files are supported", + ) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' successfully inserted", + document_count=1, + ) + except UnicodeDecodeError: + raise HTTPException(status_code=400, detail="File encoding not supported") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_batch(files: List[UploadFile] = File(...)): + try: + inserted_count = 0 + failed_files = [] + + for file in files: + try: + content = await file.read() + if file.filename.endswith((".txt", ".md")): + text = content.decode("utf-8") + await rag.ainsert(text) + inserted_count += 1 + else: + failed_files.append(f"{file.filename} (unsupported type)") + except Exception as e: + failed_files.append(f"{file.filename} ({str(e)})") + + status_message = f"Successfully inserted {inserted_count} documents" + if failed_files: + status_message += f". 
Failed files: {', '.join(failed_files)}" + + return InsertResponse( + status="success" if inserted_count > 0 else "partial_success", + message=status_message, + document_count=len(files), + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.delete( + "/documents", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def clear_documents(): + try: + rag.text_chunks = [] + rag.entities_vdb = None + rag.relationships_vdb = None + return InsertResponse( + status="success", + message="All documents cleared successfully", + document_count=0, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.get("/health", dependencies=[Depends(optional_api_key)]) + async def get_status(): + """Get current system status""" + return { + "status": "healthy", + "working_directory": str(args.working_dir), + "input_directory": str(args.input_dir), + "indexed_files": len(doc_manager.indexed_files), + "configuration": { + # LLM configuration binding/host address (if applicable)/model (if applicable) + "llm_binding": args.llm_binding, + "llm_binding_host": args.llm_binding_host, + "llm_model": args.llm_model, + # embedding model configuration binding/host address (if applicable)/model (if applicable) + "embedding_binding": args.embedding_binding, + "embedding_binding_host": args.embedding_binding_host, + "embedding_model": args.embedding_model, + "max_tokens": args.max_tokens, + }, + } + + return app + + +def main(): + args = parse_args() + import uvicorn + + app = create_app(args) + uvicorn_config = { + "app": app, + "host": args.host, + "port": args.port, + } + if args.ssl: + uvicorn_config.update( + { + "ssl_certfile": args.ssl_certfile, + "ssl_keyfile": args.ssl_keyfile, + } + ) + uvicorn.run(**uvicorn_config) + + +if __name__ == "__main__": + main() From b97d1ecd72f024738e8ea49226f4a2bb72630d57 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 13:35:20 +0800 Subject: [PATCH 09/42] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/{lightrag_ollama_server.py => lightrag_ollama.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename lightrag/api/{lightrag_ollama_server.py => lightrag_ollama.py} (100%) diff --git a/lightrag/api/lightrag_ollama_server.py b/lightrag/api/lightrag_ollama.py similarity index 100% rename from lightrag/api/lightrag_ollama_server.py rename to lightrag/api/lightrag_ollama.py From be134878fe505e540904d62c3989abb72aaf2258 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 14:31:49 +0800 Subject: [PATCH 10/42] =?UTF-8?q?=E5=AE=8C=E6=88=90ollma=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E7=9A=84=E4=BB=A3=E7=A0=81=E7=BC=96=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 135 ++++++++++++++++++++++++++++++-- setup.py | 1 + 2 files changed, 129 insertions(+), 7 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 42ae68f4..6f1ec9a4 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -2,14 +2,11 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form from pydantic import BaseModel import logging import argparse +from typing import List, Dict, Any, Optional from lightrag import LightRAG, QueryParam -# from lightrag.llm import lollms_model_complete, lollms_embed -# from 
lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding from lightrag.llm import openai_complete_if_cache, ollama_embedding -# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding from lightrag.utils import EmbeddingFunc -from typing import Optional, List from enum import Enum from pathlib import Path import shutil @@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN from dotenv import load_dotenv load_dotenv() +# Constants for model information +LIGHTRAG_NAME = "lightrag" +LIGHTRAG_TAG = "latest" +LIGHTRAG_MODEL = "{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" +LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" +LIGHTRAG_DIGEST = "sha256:lightrag" + async def llm_model_func( prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs ) -> str: @@ -219,21 +223,43 @@ class DocumentManager: class SearchMode(str, Enum): naive = "naive" local = "local" - global_ = "global" + global_ = "global" # 使用 global_ 因为 global 是 Python 保留关键字,但枚举值会转换为字符串 "global" hybrid = "hybrid" +# Ollama API compatible models +class OllamaMessage(BaseModel): + role: str + content: str +class OllamaChatRequest(BaseModel): + model: str = LIGHTRAG_MODEL + messages: List[OllamaMessage] + stream: bool = False + options: Optional[Dict[str, Any]] = None + +class OllamaChatResponse(BaseModel): + model: str + created_at: str + message: OllamaMessage + done: bool + +class OllamaVersionResponse(BaseModel): + version: str + build: str = "default" + +class OllamaTagResponse(BaseModel): + models: List[Dict[str, str]] + +# Original LightRAG models class QueryRequest(BaseModel): query: str mode: SearchMode = SearchMode.hybrid stream: bool = False only_need_context: bool = False - class QueryResponse(BaseModel): response: str - class InsertTextRequest(BaseModel): text: str description: Optional[str] = None @@ -555,6 +581,101 @@ def create_app(args): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + # Ollama compatible API endpoints + @app.get("/api/version") + async def get_version(): + """Get Ollama version information""" + return OllamaVersionResponse( + version="0.1.0" + ) + + @app.get("/api/tags") + async def get_tags(): + """Get available models""" + return OllamaTagResponse( + models=[{ + "name": LIGHTRAG_NAME, + "tag": LIGHTRAG_TAG, + "size": 0, + "digest": LIGHTRAG_DIGEST, + "modified_at": LIGHTRAG_CREATED_AT + }] + ) + + def parse_query_mode(query: str) -> tuple[str, SearchMode]: + """Parse query prefix to determine search mode + Returns tuple of (cleaned_query, search_mode) + """ + mode_map = { + "/local ": SearchMode.local, + "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword + "/naive ": SearchMode.naive, + "/hybrid ": SearchMode.hybrid + } + + for prefix, mode in mode_map.items(): + if query.startswith(prefix): + return query[len(prefix):], mode + + return query, SearchMode.hybrid + + @app.post("/api/chat") + async def chat(request: OllamaChatRequest): + """Handle chat completion requests""" + try: + # Convert chat format to query + query = request.messages[-1].content if request.messages else "" + + # Parse query mode and clean query + cleaned_query, mode = parse_query_mode(query) + + # Call RAG with determined mode + response = await rag.aquery( + cleaned_query, + param=QueryParam( + mode=mode, + stream=request.stream + ) + ) + + if request.stream: + async def stream_generator(): + result = "" + async for chunk in response: + result += chunk + yield OllamaChatResponse( + model=LIGHTRAG_MODEL, + 
created_at=LIGHTRAG_CREATED_AT, + message=OllamaMessage( + role="assistant", + content=chunk + ), + done=False + ) + # Send final message + yield OllamaChatResponse( + model=LIGHTRAG_MODEL, + created_at=LIGHTRAG_CREATED_AT, + message=OllamaMessage( + role="assistant", + content=result + ), + done=True + ) + return stream_generator() + else: + return OllamaChatResponse( + model=LIGHTRAG_MODEL, + created_at=LIGHTRAG_CREATED_AT, + message=OllamaMessage( + role="assistant", + content=response + ), + done=True + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" diff --git a/setup.py b/setup.py index 38eff646..b5850d26 100644 --- a/setup.py +++ b/setup.py @@ -101,6 +101,7 @@ setuptools.setup( entry_points={ "console_scripts": [ "lightrag-server=lightrag.api.lightrag_server:main [api]", + "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]", ], }, ) From c1f4f4a20e1e5474fb56bb1443a0664881fda995 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 15:06:28 +0800 Subject: [PATCH 11/42] =?UTF-8?q?=E4=BC=98=E5=8C=96ollama=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E6=95=B0=E6=8D=AE=E7=9A=84=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 6f1ec9a4..e03c91bd 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -27,6 +27,7 @@ load_dotenv() LIGHTRAG_NAME = "lightrag" LIGHTRAG_TAG = "latest" LIGHTRAG_MODEL = "{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" +LIGHTRAG_SIZE = 7365960935 LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" LIGHTRAG_DIGEST = "sha256:lightrag" @@ -245,7 +246,6 @@ class OllamaChatResponse(BaseModel): class OllamaVersionResponse(BaseModel): version: str - build: str = "default" class OllamaTagResponse(BaseModel): models: List[Dict[str, str]] @@ -586,7 +586,7 @@ def create_app(args): async def get_version(): """Get Ollama version information""" return OllamaVersionResponse( - version="0.1.0" + version="0.5.4" ) @app.get("/api/tags") @@ -595,10 +595,19 @@ def create_app(args): return OllamaTagResponse( models=[{ "name": LIGHTRAG_NAME, + "model": LIGHTRAG_NAME, "tag": LIGHTRAG_TAG, - "size": 0, + "size": LIGHTRAG_SIZE, "digest": LIGHTRAG_DIGEST, - "modified_at": LIGHTRAG_CREATED_AT + "modified_at": LIGHTRAG_CREATED_AT, + "details": { + "parent_model": "", + "format": "gguf", + "family": LIGHTRAG_NAME, + "families": [LIGHTRAG_NAME], + "parameter_size": "13B", + "quantization_level": "Q4_0" + } }] ) From fd50c3a240dddc2cd6d11adcf69a40094ea61a47 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 17:43:00 +0800 Subject: [PATCH 12/42] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=81=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BF=AE=E6=94=B9=20/api/t?= =?UTF-8?q?ags=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 67 +++++++++++++++++++++++++-------- start-server.sh | 2 +- 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index e03c91bd..1d8bc86e 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -26,7 +26,7 @@ load_dotenv() # Constants for model 
information LIGHTRAG_NAME = "lightrag" LIGHTRAG_TAG = "latest" -LIGHTRAG_MODEL = "{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" +LIGHTRAG_MODEL = "lightrag:latest" LIGHTRAG_SIZE = 7365960935 LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" LIGHTRAG_DIGEST = "sha256:lightrag" @@ -247,8 +247,25 @@ class OllamaChatResponse(BaseModel): class OllamaVersionResponse(BaseModel): version: str +class OllamaModelDetails(BaseModel): + parent_model: str + format: str + family: str + families: List[str] + parameter_size: str + quantization_level: str + +class OllamaModel(BaseModel): + name: str + model: str + tag: str + size: int + digest: str + modified_at: str + details: OllamaModelDetails + class OllamaTagResponse(BaseModel): - models: List[Dict[str, str]] + models: List[OllamaModel] # Original LightRAG models class QueryRequest(BaseModel): @@ -632,26 +649,46 @@ def create_app(args): async def chat(request: OllamaChatRequest): """Handle chat completion requests""" try: - # Convert chat format to query - query = request.messages[-1].content if request.messages else "" + # 获取所有消息内容 + messages = request.messages + if not messages: + raise HTTPException(status_code=400, detail="No messages provided") - # Parse query mode and clean query + # 获取最后一条消息作为查询 + query = messages[-1].content + + # 解析查询模式 cleaned_query, mode = parse_query_mode(query) - # Call RAG with determined mode - response = await rag.aquery( - cleaned_query, - param=QueryParam( + # 构建系统提示词(如果有历史消息) + system_prompt = None + history_messages = [] + if len(messages) > 1: + # 如果第一条消息是系统消息,提取为system_prompt + if messages[0].role == "system": + system_prompt = messages[0].content + messages = messages[1:] + + # 收集历史消息(除了最后一条) + history_messages = [(msg.role, msg.content) for msg in messages[:-1]] + + # 调用RAG进行查询 + kwargs = { + "param": QueryParam( mode=mode, - stream=request.stream + stream=request.stream, ) - ) + } + if system_prompt is not None: + kwargs["system_prompt"] = system_prompt + if history_messages: + kwargs["history_messages"] = history_messages + + response = await rag.aquery(cleaned_query, **kwargs) if request.stream: async def stream_generator(): - result = "" async for chunk in response: - result += chunk yield OllamaChatResponse( model=LIGHTRAG_MODEL, created_at=LIGHTRAG_CREATED_AT, @@ -661,13 +698,13 @@ def create_app(args): ), done=False ) - # Send final message + # 发送一个空的完成消息 yield OllamaChatResponse( model=LIGHTRAG_MODEL, created_at=LIGHTRAG_CREATED_AT, message=OllamaMessage( role="assistant", - content=result + content="" ), done=True ) diff --git a/start-server.sh b/start-server.sh index 2f712098..a250a3c3 100755 --- a/start-server.sh +++ b/start-server.sh @@ -1 +1 @@ -lightrag-server --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024 +lightrag-ollama --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024 From 882da8860353534843ea20f235d1ece1abe6b036 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 18:19:39 +0800 Subject: [PATCH 13/42] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=90=8D=E7=A7=B0=E8=BF=94=E5=9B=9E=E9=94=99=E8=AF=AF=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 1d8bc86e..3c7da1ea 100644 --- a/lightrag/api/lightrag_ollama.py +++ 
b/lightrag/api/lightrag_ollama.py @@ -258,7 +258,6 @@ class OllamaModelDetails(BaseModel): class OllamaModel(BaseModel): name: str model: str - tag: str size: int digest: str modified_at: str @@ -611,9 +610,8 @@ def create_app(args): """Get available models""" return OllamaTagResponse( models=[{ - "name": LIGHTRAG_NAME, - "model": LIGHTRAG_NAME, - "tag": LIGHTRAG_TAG, + "name": LIGHTRAG_MODEL, + "model": LIGHTRAG_MODEL, "size": LIGHTRAG_SIZE, "digest": LIGHTRAG_DIGEST, "modified_at": LIGHTRAG_CREATED_AT, From 4e5517a602d1a23bcf3797f1e7793401167a0976 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 18:27:35 +0800 Subject: [PATCH 14/42] =?UTF-8?q?=E4=BF=AE=E5=A4=8Drag=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E4=B8=8D=E6=AD=A3=E7=A1=AE=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 3c7da1ea..39c8256a 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -658,31 +658,14 @@ def create_app(args): # 解析查询模式 cleaned_query, mode = parse_query_mode(query) - # 构建系统提示词(如果有历史消息) - system_prompt = None - history_messages = [] - if len(messages) > 1: - # 如果第一条消息是系统消息,提取为system_prompt - if messages[0].role == "system": - system_prompt = messages[0].content - messages = messages[1:] - - # 收集历史消息(除了最后一条) - history_messages = [(msg.role, msg.content) for msg in messages[:-1]] - # 调用RAG进行查询 - kwargs = { - "param": QueryParam( + response = await rag.aquery( + cleaned_query, + param=QueryParam( mode=mode, stream=request.stream, ) - } - if system_prompt is not None: - kwargs["system_prompt"] = system_prompt - if history_messages: - kwargs["history_messages"] = history_messages - - response = await rag.aquery(cleaned_query, **kwargs) + ) if request.stream: async def stream_generator(): From 828af49d6bacf5d30d7978e07a952244060f9adf Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 18:47:01 +0800 Subject: [PATCH 15/42] =?UTF-8?q?=E8=83=BD=E5=A4=9F=E6=AD=A3=E7=A1=AE?= =?UTF-8?q?=E8=B0=83=E7=94=A8rag=EF=BC=8Crag=E6=89=A7=E8=A1=8C=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E5=90=8E=EF=BC=8C=E6=97=A0=E6=B3=95=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 75 ++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 39c8256a..4e83acb0 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -659,38 +659,55 @@ def create_app(args): cleaned_query, mode = parse_query_mode(query) # 调用RAG进行查询 - response = await rag.aquery( - cleaned_query, - param=QueryParam( - mode=mode, - stream=request.stream, - ) - ) - if request.stream: + response = await rag.aquery( + cleaned_query, + param=QueryParam( + mode=mode, + stream=True, + only_need_context=False + ), + ) + async def stream_generator(): - async for chunk in response: - yield OllamaChatResponse( - model=LIGHTRAG_MODEL, - created_at=LIGHTRAG_CREATED_AT, - message=OllamaMessage( - role="assistant", - content=chunk - ), - done=False - ) - # 发送一个空的完成消息 - yield OllamaChatResponse( - model=LIGHTRAG_MODEL, - created_at=LIGHTRAG_CREATED_AT, - message=OllamaMessage( - role="assistant", - content="" - ), - 
done=True - ) - return stream_generator() + try: + async for chunk in response: + yield { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk + }, + "done": False + } + yield { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "" + }, + "done": True + } + except Exception as e: + logging.error(f"Error in stream_generator: {str(e)}") + raise + from fastapi.responses import StreamingResponse + import json + return StreamingResponse( + (f"data: {json.dumps(chunk)}\n\n" async for chunk in stream_generator()), + media_type="text/event-stream" + ) else: + response = await rag.aquery( + cleaned_query, + param=QueryParam( + mode=mode, + stream=False, + only_need_context=False + ), + ) return OllamaChatResponse( model=LIGHTRAG_MODEL, created_at=LIGHTRAG_CREATED_AT, From f15f97a51d8a778d5e4116abc38b15707ee355c6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 19:32:03 +0800 Subject: [PATCH 16/42] =?UTF-8?q?=E4=B8=B4=E6=97=B6=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 124 +++++++++++++++++++++++++------- 1 file changed, 97 insertions(+), 27 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 4e83acb0..cc549f4b 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -472,10 +472,25 @@ def create_app(args): ) if request.stream: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) + from fastapi.responses import StreamingResponse + import json + + async def stream_generator(): + async for chunk in response: + yield f"data: {json.dumps({'response': chunk})}\n\n" + + return StreamingResponse( + stream_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type" + } + ) else: return QueryResponse(response=response) except Exception as e: @@ -484,7 +499,7 @@ def create_app(args): @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) async def query_text_stream(request: QueryRequest): try: - response = rag.query( + response = await rag.aquery( # 使用 aquery 而不是 query,并添加 await request.query, param=QueryParam( mode=request.mode, @@ -493,11 +508,24 @@ def create_app(args): ), ) + from fastapi.responses import StreamingResponse + async def stream_generator(): async for chunk in response: - yield chunk + yield f"data: {chunk}\n\n" - return stream_generator() + return StreamingResponse( + stream_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type" + } + ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -659,20 +687,48 @@ def create_app(args): cleaned_query, mode = parse_query_mode(query) # 调用RAG进行查询 + query_param = QueryParam( + mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid + stream=request.stream, + only_need_context=False + ) + if request.stream: - response = await rag.aquery( + from fastapi.responses import StreamingResponse + 
import json + + response = await rag.aquery( # 需要 await 来获取异步生成器 cleaned_query, - param=QueryParam( - mode=mode, - stream=True, - only_need_context=False - ), + param=query_param ) async def stream_generator(): try: - async for chunk in response: - yield { + # 确保 response 是异步生成器 + if isinstance(response, str): + data = { + 'model': LIGHTRAG_MODEL, + 'created_at': LIGHTRAG_CREATED_AT, + 'message': { + 'role': 'assistant', + 'content': response + }, + 'done': True + } + yield f"data: {json.dumps(data)}\n\n" + else: + async for chunk in response: + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk + }, + "done": False + } + yield f"data: {json.dumps(data)}\n\n" + data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "message": { @@ -681,7 +737,10 @@ def create_app(args): }, "done": False } - yield { + yield f"data: {json.dumps(data)}\n\n" + + # 发送完成标记 + data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "message": { @@ -690,30 +749,41 @@ def create_app(args): }, "done": True } + yield f"data: {json.dumps(data)}\n\n" except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") raise - from fastapi.responses import StreamingResponse - import json + return StreamingResponse( - (f"data: {json.dumps(chunk)}\n\n" async for chunk in stream_generator()), - media_type="text/event-stream" + stream_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type" + } ) else: - response = await rag.aquery( + # 非流式响应 + response_text = await rag.aquery( cleaned_query, - param=QueryParam( - mode=mode, - stream=False, - only_need_context=False - ), + param=query_param ) + + # 确保响应不为空 + if not response_text: + response_text = "No response generated" + + # 构造并返回响应 return OllamaChatResponse( model=LIGHTRAG_MODEL, created_at=LIGHTRAG_CREATED_AT, message=OllamaMessage( role="assistant", - content=response + content=str(response_text) # 确保转换为字符串 ), done=True ) From 23f838ec946d8bfea26462bee7bb84a79cdb941f Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 20:18:17 +0800 Subject: [PATCH 17/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E5=93=8D=E5=BA=94=E5=A4=84=E7=90=86=E5=B9=B6=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复流式响应中的完成标记逻辑 - 添加非流式调用测试 - 添加流式调用测试 - 优化JSON序列化,支持非ASCII字符 - 确保生成器在完成标记后立即结束 --- lightrag/api/lightrag_ollama.py | 66 ++++++++++++++--------------- test_lightrag_ollama_chat.py | 73 +++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 36 deletions(-) create mode 100644 test_lightrag_ollama_chat.py diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index cc549f4b..004c2739 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -706,50 +706,44 @@ def create_app(args): try: # 确保 response 是异步生成器 if isinstance(response, str): - data = { - 'model': LIGHTRAG_MODEL, - 'created_at': LIGHTRAG_CREATED_AT, - 'message': { - 'role': 'assistant', - 'content': response - }, - 'done': True - } - yield f"data: {json.dumps(data)}\n\n" - else: - async for chunk in response: - data = { - "model": LIGHTRAG_MODEL, - "created_at": 
LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk - }, - "done": False - } - yield f"data: {json.dumps(data)}\n\n" + # 如果是字符串,作为单个完整响应发送 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": chunk + "content": response }, - "done": False + "done": True } - yield f"data: {json.dumps(data)}\n\n" + yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + else: + # 流式响应 + async for chunk in response: + if chunk: # 只发送非空内容 + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk + }, + "done": False + } + yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" - # 发送完成标记 - data = { - "model": LIGHTRAG_MODEL, - "created_at": LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": "" - }, - "done": True - } - yield f"data: {json.dumps(data)}\n\n" + # 发送完成标记 + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "" + }, + "done": True + } + yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + return # 确保生成器在发送完成标记后立即结束 except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") raise diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py new file mode 100644 index 00000000..067b8877 --- /dev/null +++ b/test_lightrag_ollama_chat.py @@ -0,0 +1,73 @@ +import requests +import json +import sseclient + +def test_non_stream_chat(): + """测试非流式调用 /api/chat 接口""" + url = "http://localhost:9621/api/chat" + + # 构造请求数据 + data = { + "model": "lightrag:latest", + "messages": [ + { + "role": "user", + "content": "孙悟空" + } + ], + "stream": False + } + + # 发送请求 + response = requests.post(url, json=data) + + # 打印响应 + print("\n=== 非流式调用响应 ===") + print(json.dumps(response.json(), ensure_ascii=False, indent=2)) + +def test_stream_chat(): + """测试流式调用 /api/chat 接口""" + url = "http://localhost:9621/api/chat" + + # 构造请求数据 + data = { + "model": "lightrag:latest", + "messages": [ + { + "role": "user", + "content": "/naive 孙悟空有什么法力,性格特征是什么" + } + ], + "stream": True + } + + # 发送请求并获取 SSE 流 + response = requests.post(url, json=data, stream=True) + client = sseclient.SSEClient(response) + + print("\n=== 流式调用响应 ===") + output_buffer = [] + try: + for event in client.events(): + try: + data = json.loads(event.data) + message = data.get("message", {}) + content = message.get("content", "") + if content: # 只收集非空内容 + output_buffer.append(content) + if data.get("done", False): # 如果收到完成标记,退出循环 + break + except json.JSONDecodeError: + print("Error decoding JSON from SSE event") + finally: + response.close() # 确保关闭响应连接 + + # 一次性打印所有收集到的内容 + print("".join(output_buffer)) + +if __name__ == "__main__": + # 先测试非流式调用 + test_non_stream_chat() + + # 再测试流式调用 + test_stream_chat() From f81b1cdf0a65cc157e9d2402f54ae05aac432e24 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 20:46:45 +0800 Subject: [PATCH 18/42] =?UTF-8?q?=E4=B8=BAOllama=20API=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E7=BB=93=E6=9E=9C=E6=B7=BB=E5=8A=A0=E5=9B=BE=E5=83=8F=E5=AD=97?= =?UTF-8?q?=E6=AE=B5=E5=92=8C=E6=80=A7=E8=83=BD=E7=BB=9F=E8=AE=A1=E4=BF=A1?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在OllamaMessage中添加images字段 - 响应消息中增加images字段 - 完成标记中添加性能统计信息 - 更新测试用例以处理性能统计 - 移除测试用例中的/naive前缀 --- lightrag/api/lightrag_ollama.py | 24 +++++++++++++++--------- test_lightrag_ollama_chat.py | 18 +++++++++++------- 2 files 
changed, 26 insertions(+), 16 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 004c2739..3b92902f 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -231,6 +231,7 @@ class SearchMode(str, Enum): class OllamaMessage(BaseModel): role: str content: str + images: Optional[List[str]] = None class OllamaChatRequest(BaseModel): model: str = LIGHTRAG_MODEL @@ -712,7 +713,8 @@ def create_app(args): "created_at": LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": response + "content": response, + "images": None }, "done": True } @@ -726,21 +728,24 @@ def create_app(args): "created_at": LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": chunk + "content": chunk, + "images": None }, "done": False } yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" - # 发送完成标记 + # 发送完成标记,包含性能统计信息 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": "" - }, - "done": True + "done": True, + "total_duration": 0, # 由于我们没有实际统计这些指标,暂时使用默认值 + "load_duration": 0, + "prompt_eval_count": 0, + "prompt_eval_duration": 0, + "eval_count": 0, + "eval_duration": 0 } yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" return # 确保生成器在发送完成标记后立即结束 @@ -777,7 +782,8 @@ def create_app(args): created_at=LIGHTRAG_CREATED_AT, message=OllamaMessage( role="assistant", - content=str(response_text) # 确保转换为字符串 + content=str(response_text), # 确保转换为字符串 + images=None ), done=True ) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 067b8877..b941ee27 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -35,7 +35,7 @@ def test_stream_chat(): "messages": [ { "role": "user", - "content": "/naive 孙悟空有什么法力,性格特征是什么" + "content": "孙悟空有什么法力,性格特征是什么" } ], "stream": True @@ -51,12 +51,16 @@ def test_stream_chat(): for event in client.events(): try: data = json.loads(event.data) - message = data.get("message", {}) - content = message.get("content", "") - if content: # 只收集非空内容 - output_buffer.append(content) - if data.get("done", False): # 如果收到完成标记,退出循环 - break + if data.get("done", False): # 如果是完成标记 + if "total_duration" in data: # 最终的性能统计消息 + print("\n=== 性能统计 ===") + print(json.dumps(data, ensure_ascii=False, indent=2)) + break + else: # 正常的内容消息 + message = data.get("message", {}) + content = message.get("content", "") + if content: # 只收集非空内容 + output_buffer.append(content) except json.JSONDecodeError: print("Error decoding JSON from SSE event") finally: From 8ef1248c761a2725014d889099e799bf194b532c Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 20:54:22 +0800 Subject: [PATCH 19/42] =?UTF-8?q?=E5=B0=86OllamaChatRequest=E7=9A=84stream?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E9=BB=98=E8=AE=A4=E5=80=BC=E6=94=B9=E4=B8=BA?= =?UTF-8?q?True?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 3b92902f..54581a6f 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -236,7 +236,7 @@ class OllamaMessage(BaseModel): class OllamaChatRequest(BaseModel): model: str = LIGHTRAG_MODEL messages: List[OllamaMessage] - stream: bool = False + stream: bool = True # 默认为流式模式 options: Optional[Dict[str, Any]] = None class OllamaChatResponse(BaseModel): From 
af9ac188f01403383de8ee3f631a9d7ab5c89690 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 21:15:12 +0800 Subject: [PATCH 20/42] =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E7=9A=84=E8=B0=83=E8=AF=95=E5=92=8C=E6=80=A7?= =?UTF-8?q?=E8=83=BD=E7=BB=9F=E8=AE=A1=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加原始请求日志记录 - 修改响应结构以包含性能统计 - 更新测试用例以展示性能数据 - 优化响应格式为字典结构 - 增加请求体解码功能 --- lightrag/api/lightrag_ollama.py | 35 +++++++++++++++++++++------------ test_lightrag_ollama_chat.py | 21 +++++++++++++++++++- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 54581a6f..959506d5 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form +from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request from pydantic import BaseModel import logging import argparse @@ -673,7 +673,10 @@ def create_app(args): return query, SearchMode.hybrid @app.post("/api/chat") - async def chat(request: OllamaChatRequest): + async def chat(raw_request: Request, request: OllamaChatRequest): + # 打印原始请求数据 + body = await raw_request.body() + logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}") """Handle chat completion requests""" try: # 获取所有消息内容 @@ -776,17 +779,23 @@ def create_app(args): if not response_text: response_text = "No response generated" - # 构造并返回响应 - return OllamaChatResponse( - model=LIGHTRAG_MODEL, - created_at=LIGHTRAG_CREATED_AT, - message=OllamaMessage( - role="assistant", - content=str(response_text), # 确保转换为字符串 - images=None - ), - done=True - ) + # 构造响应,包含性能统计信息 + return { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": str(response_text), # 确保转换为字符串 + "images": None + }, + "done": True, + "total_duration": 0, # 由于我们没有实际统计这些指标,暂时使用默认值 + "load_duration": 0, + "prompt_eval_count": 0, + "prompt_eval_duration": 0, + "eval_count": 0, + "eval_duration": 0 + } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index b941ee27..60158ac2 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -23,7 +23,26 @@ def test_non_stream_chat(): # 打印响应 print("\n=== 非流式调用响应 ===") - print(json.dumps(response.json(), ensure_ascii=False, indent=2)) + response_json = response.json() + + # 打印消息内容 + print("=== 响应内容 ===") + print(json.dumps({ + "model": response_json["model"], + "message": response_json["message"] + }, ensure_ascii=False, indent=2)) + + # 打印性能统计 + print("\n=== 性能统计 ===") + stats = { + "total_duration": response_json["total_duration"], + "load_duration": response_json["load_duration"], + "prompt_eval_count": response_json["prompt_eval_count"], + "prompt_eval_duration": response_json["prompt_eval_duration"], + "eval_count": response_json["eval_count"], + "eval_duration": response_json["eval_duration"] + } + print(json.dumps(stats, ensure_ascii=False, indent=2)) def test_stream_chat(): """测试流式调用 /api/chat 接口""" From 6d44178f63418eb8b992b5b40b369596e08f27e1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 21:26:20 +0800 Subject: [PATCH 21/42] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B=E6=B5=81=E7=BB=93=E6=9D=9F=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 24 ++++++++++++------------ test_lightrag_ollama_chat.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 959506d5..5a066e15 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -743,12 +743,12 @@ def create_app(args): "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "done": True, - "total_duration": 0, # 由于我们没有实际统计这些指标,暂时使用默认值 - "load_duration": 0, - "prompt_eval_count": 0, - "prompt_eval_duration": 0, - "eval_count": 0, - "eval_duration": 0 + "total_duration": 1, # 由于我们没有实际统计这些指标,暂时使用默认值 + "load_duration": 1, + "prompt_eval_count": 999, + "prompt_eval_duration": 1, + "eval_count": 999, + "eval_duration": 1 } yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" return # 确保生成器在发送完成标记后立即结束 @@ -789,12 +789,12 @@ def create_app(args): "images": None }, "done": True, - "total_duration": 0, # 由于我们没有实际统计这些指标,暂时使用默认值 - "load_duration": 0, - "prompt_eval_count": 0, - "prompt_eval_duration": 0, - "eval_count": 0, - "eval_duration": 0 + "total_duration": 1, # 由于我们没有实际统计这些指标,暂时使用默认值 + "load_duration": 1, + "prompt_eval_count": 999, + "prompt_eval_duration": 1, + "eval_count": 999, + "eval_duration": 1 } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 60158ac2..f8947585 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -70,7 +70,7 @@ def test_stream_chat(): for event in client.events(): try: data = json.loads(event.data) - if data.get("done", False): # 如果是完成标记 + if data.get("done", True): # 如果是完成标记 if "total_duration" in data: # 最终的性能统计消息 print("\n=== 性能统计 ===") print(json.dumps(data, ensure_ascii=False, indent=2)) From ca2caf47bc0ed52899c06c7507cea7a5912fb5ed Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 22:14:57 +0800 Subject: [PATCH 22/42] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E5=93=8D=E5=BA=94=E7=9A=84=E8=BE=93=E5=87=BA=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=EF=BC=9A=E4=BB=8Eevent-stream=E6=94=B9=E4=B8=BAx-ndjson?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 5a066e15..714731eb 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request from pydantic import BaseModel import logging import argparse +import json from typing import List, Dict, Any, Optional from lightrag import LightRAG, QueryParam from lightrag.llm import openai_complete_if_cache, ollama_embedding @@ -474,19 +475,18 @@ def create_app(args): if request.stream: from fastapi.responses import StreamingResponse - import json async def stream_generator(): async for chunk in response: - yield f"data: {json.dumps({'response': chunk})}\n\n" + yield f"{json.dumps({'response': chunk})}\n" return StreamingResponse( stream_generator(), - media_type="text/event-stream", + media_type="application/x-ndjson", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", - "Content-Type": "text/event-stream", + "Content-Type": "application/x-ndjson", "Access-Control-Allow-Origin": "*", 
"Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Headers": "Content-Type" @@ -513,15 +513,15 @@ def create_app(args): async def stream_generator(): async for chunk in response: - yield f"data: {chunk}\n\n" + yield f"{json.dumps({'response': chunk})}\n" return StreamingResponse( stream_generator(), - media_type="text/event-stream", + media_type="application/x-ndjson", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", - "Content-Type": "text/event-stream", + "Content-Type": "application/x-ndjson", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Headers": "Content-Type" @@ -699,7 +699,6 @@ def create_app(args): if request.stream: from fastapi.responses import StreamingResponse - import json response = await rag.aquery( # 需要 await 来获取异步生成器 cleaned_query, @@ -721,7 +720,7 @@ def create_app(args): }, "done": True } - yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + yield f"{json.dumps(data, ensure_ascii=False)}\n" else: # 流式响应 async for chunk in response: @@ -736,7 +735,7 @@ def create_app(args): }, "done": False } - yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + yield f"{json.dumps(data, ensure_ascii=False)}\n" # 发送完成标记,包含性能统计信息 data = { @@ -750,7 +749,7 @@ def create_app(args): "eval_count": 999, "eval_duration": 1 } - yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + yield f"{json.dumps(data, ensure_ascii=False)}\n" return # 确保生成器在发送完成标记后立即结束 except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") @@ -758,11 +757,11 @@ def create_app(args): return StreamingResponse( stream_generator(), - media_type="text/event-stream", + media_type="application/x-ndjson", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", - "Content-Type": "text/event-stream", + "Content-Type": "application/x-ndjson", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Headers": "Content-Type" From f441a454537b05182baa0c8584d48940fa455d9d Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 22:15:46 +0800 Subject: [PATCH 23/42] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + test_lightrag_ollama_chat.py | 614 ++++++++++++++++++++++++++++++++--- 2 files changed, 561 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index 5749adb5..8ac420b1 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ rag_storage venv/ examples/input/ examples/output/ +test_results.json diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index f8947585..ae32a3b8 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -1,96 +1,602 @@ +""" +LightRAG Ollama 兼容接口测试脚本 + +这个脚本测试 LightRAG 的 Ollama 兼容接口,包括: +1. 基本功能测试(流式和非流式响应) +2. 查询模式测试(local、global、naive、hybrid) +3. 
错误处理测试(包括流式和非流式场景) + +所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。 +""" + import requests import json -import sseclient +import argparse +import os +import time +from typing import Dict, Any, Optional, List, Callable, Tuple +from dataclasses import dataclass, asdict +from datetime import datetime +from pathlib import Path -def test_non_stream_chat(): - """测试非流式调用 /api/chat 接口""" - url = "http://localhost:9621/api/chat" +class OutputControl: + """输出控制类,管理测试输出的详细程度""" + _verbose: bool = False + + @classmethod + def set_verbose(cls, verbose: bool) -> None: + """设置输出详细程度 + + Args: + verbose: True 为详细模式,False 为静默模式 + """ + cls._verbose = verbose + + @classmethod + def is_verbose(cls) -> bool: + """获取当前输出模式 + + Returns: + 当前是否为详细模式 + """ + return cls._verbose + +@dataclass +class TestResult: + """测试结果数据类""" + name: str + success: bool + duration: float + error: Optional[str] = None + timestamp: str = "" - # 构造请求数据 - data = { + def __post_init__(self): + """初始化后设置时间戳""" + if not self.timestamp: + self.timestamp = datetime.now().isoformat() + +class TestStats: + """测试统计信息""" + def __init__(self): + self.results: List[TestResult] = [] + self.start_time = datetime.now() + + def add_result(self, result: TestResult): + """添加测试结果""" + self.results.append(result) + + def export_results(self, path: str = "test_results.json"): + """导出测试结果到 JSON 文件 + + Args: + path: 输出文件路径 + """ + results_data = { + "start_time": self.start_time.isoformat(), + "end_time": datetime.now().isoformat(), + "results": [asdict(r) for r in self.results], + "summary": { + "total": len(self.results), + "passed": sum(1 for r in self.results if r.success), + "failed": sum(1 for r in self.results if not r.success), + "total_duration": sum(r.duration for r in self.results) + } + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(results_data, f, ensure_ascii=False, indent=2) + print(f"\n测试结果已保存到: {path}") + + def print_summary(self): + """打印测试统计摘要""" + total = len(self.results) + passed = sum(1 for r in self.results if r.success) + failed = total - passed + duration = sum(r.duration for r in self.results) + + print("\n=== 测试结果摘要 ===") + print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"总用时: {duration:.2f}秒") + print(f"总计: {total} 个测试") + print(f"通过: {passed} 个") + print(f"失败: {failed} 个") + + if failed > 0: + print("\n失败的测试:") + for result in self.results: + if not result.success: + print(f"- {result.name}: {result.error}") + +# 默认配置 +DEFAULT_CONFIG = { + "server": { + "host": "localhost", + "port": 9621, "model": "lightrag:latest", + "timeout": 30, # 请求超时时间(秒) + "max_retries": 3, # 最大重试次数 + "retry_delay": 1 # 重试间隔(秒) + }, + "test_cases": { + "basic": { + "query": "孙悟空", + "stream_query": "孙悟空有什么法力,性格特征是什么" + } + } +} + +def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response: + """发送 HTTP 请求,支持重试机制 + + Args: + url: 请求 URL + data: 请求数据 + stream: 是否使用流式响应 + + Returns: + requests.Response 对象 + + Raises: + requests.exceptions.RequestException: 请求失败且重试次数用完 + """ + server_config = CONFIG["server"] + max_retries = server_config["max_retries"] + retry_delay = server_config["retry_delay"] + timeout = server_config["timeout"] + + for attempt in range(max_retries): + try: + response = requests.post( + url, + json=data, + stream=stream, + timeout=timeout + ) + return response + except requests.exceptions.RequestException as e: + if attempt == max_retries - 1: # 最后一次重试 + raise + print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}") + time.sleep(retry_delay) + +def load_config() -> Dict[str, 
Any]: + """加载配置文件 + + 首先尝试从当前目录的 config.json 加载, + 如果不存在则使用默认配置 + + Returns: + 配置字典 + """ + config_path = Path("config.json") + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + return DEFAULT_CONFIG + +def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: + """格式化打印 JSON 响应数据 + + Args: + data: 要打印的数据字典 + title: 打印的标题 + indent: JSON 缩进空格数 + """ + if OutputControl.is_verbose(): + if title: + print(f"\n=== {title} ===") + print(json.dumps(data, ensure_ascii=False, indent=indent)) + +# 全局配置 +CONFIG = load_config() + +def get_base_url() -> str: + """返回基础 URL""" + server = CONFIG["server"] + return f"http://{server['host']}:{server['port']}/api/chat" + +def create_request_data( + content: str, + stream: bool = False, + model: str = None +) -> Dict[str, Any]: + """创建基本的请求数据 + + Args: + content: 用户消息内容 + stream: 是否使用流式响应 + model: 模型名称 + + Returns: + 包含完整请求数据的字典 + """ + return { + "model": model or CONFIG["server"]["model"], "messages": [ { "role": "user", - "content": "孙悟空" + "content": content } ], - "stream": False + "stream": stream } + +# 全局测试统计 +STATS = TestStats() + +def run_test(func: Callable, name: str) -> None: + """运行测试并记录结果 + + Args: + func: 测试函数 + name: 测试名称 + """ + start_time = time.time() + try: + func() + duration = time.time() - start_time + STATS.add_result(TestResult(name, True, duration)) + except Exception as e: + duration = time.time() - start_time + STATS.add_result(TestResult(name, False, duration, str(e))) + raise + +def test_non_stream_chat(): + """测试非流式调用 /api/chat 接口""" + url = get_base_url() + data = create_request_data( + CONFIG["test_cases"]["basic"]["query"], + stream=False + ) # 发送请求 - response = requests.post(url, json=data) + response = make_request(url, data) # 打印响应 - print("\n=== 非流式调用响应 ===") + if OutputControl.is_verbose(): + print("\n=== 非流式调用响应 ===") response_json = response.json() - # 打印消息内容 - print("=== 响应内容 ===") - print(json.dumps({ + # 打印响应内容 + print_json_response({ "model": response_json["model"], "message": response_json["message"] - }, ensure_ascii=False, indent=2)) + }, "响应内容") # 打印性能统计 - print("\n=== 性能统计 ===") - stats = { + print_json_response({ "total_duration": response_json["total_duration"], "load_duration": response_json["load_duration"], "prompt_eval_count": response_json["prompt_eval_count"], "prompt_eval_duration": response_json["prompt_eval_duration"], "eval_count": response_json["eval_count"], "eval_duration": response_json["eval_duration"] - } - print(json.dumps(stats, ensure_ascii=False, indent=2)) + }, "性能统计") def test_stream_chat(): - """测试流式调用 /api/chat 接口""" - url = "http://localhost:9621/api/chat" + """测试流式调用 /api/chat 接口 - # 构造请求数据 - data = { + 使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。 + 响应格式: + { "model": "lightrag:latest", - "messages": [ - { - "role": "user", - "content": "孙悟空有什么法力,性格特征是什么" - } - ], - "stream": True + "created_at": "2024-01-15T00:00:00Z", + "message": { + "role": "assistant", + "content": "部分响应内容", + "images": null + }, + "done": false } - # 发送请求并获取 SSE 流 - response = requests.post(url, json=data, stream=True) - client = sseclient.SSEClient(response) + 最后一条消息会包含性能统计信息,done 为 true。 + """ + url = get_base_url() + data = create_request_data( + CONFIG["test_cases"]["basic"]["stream_query"], + stream=True + ) - print("\n=== 流式调用响应 ===") + # 发送请求并获取流式响应 + response = make_request(url, data, stream=True) + + if OutputControl.is_verbose(): + print("\n=== 流式调用响应 ===") output_buffer = [] try: - for event in client.events(): - try: - 
data = json.loads(event.data) - if data.get("done", True): # 如果是完成标记 - if "total_duration" in data: # 最终的性能统计消息 - print("\n=== 性能统计 ===") - print(json.dumps(data, ensure_ascii=False, indent=2)) - break - else: # 正常的内容消息 - message = data.get("message", {}) - content = message.get("content", "") - if content: # 只收集非空内容 - output_buffer.append(content) - except json.JSONDecodeError: - print("Error decoding JSON from SSE event") + for line in response.iter_lines(): + if line: # 跳过空行 + try: + # 解码并解析 JSON + data = json.loads(line.decode('utf-8')) + if data.get("done", True): # 如果是完成标记 + if "total_duration" in data: # 最终的性能统计消息 + print_json_response(data, "性能统计") + break + else: # 正常的内容消息 + message = data.get("message", {}) + content = message.get("content", "") + if content: # 只收集非空内容 + output_buffer.append(content) + print(content, end="", flush=True) # 实时打印内容 + except json.JSONDecodeError: + print("Error decoding JSON from response line") finally: response.close() # 确保关闭响应连接 - # 一次性打印所有收集到的内容 - print("".join(output_buffer)) + # 打印一个换行 + print() + +def test_query_modes(): + """测试不同的查询模式前缀 + + 支持的查询模式: + - /local: 本地检索模式,只在相关度高的文档中搜索 + - /global: 全局检索模式,在所有文档中搜索 + - /naive: 朴素模式,不使用任何优化策略 + - /hybrid: 混合模式(默认),结合多种策略 + + 每个模式都会返回相同格式的响应,但检索策略不同。 + """ + url = get_base_url() + modes = ["local", "global", "naive", "hybrid"] # 支持的查询模式 + + for mode in modes: + if OutputControl.is_verbose(): + print(f"\n=== 测试 /{mode} 模式 ===") + data = create_request_data( + f"/{mode} 孙悟空的特点", + stream=False + ) + + # 发送请求 + response = make_request(url, data) + response_json = response.json() + + # 打印响应内容 + print_json_response({ + "model": response_json["model"], + "message": response_json["message"] + }) + +def create_error_test_data(error_type: str) -> Dict[str, Any]: + """创建用于错误测试的请求数据 + + Args: + error_type: 错误类型,支持: + - empty_messages: 空消息列表 + - invalid_role: 无效的角色字段 + - missing_content: 缺少内容字段 + + Returns: + 包含错误数据的请求字典 + """ + error_data = { + "empty_messages": { + "model": "lightrag:latest", + "messages": [], + "stream": True + }, + "invalid_role": { + "model": "lightrag:latest", + "messages": [ + { + "invalid_role": "user", + "content": "测试消息" + } + ], + "stream": True + }, + "missing_content": { + "model": "lightrag:latest", + "messages": [ + { + "role": "user" + } + ], + "stream": True + } + } + return error_data.get(error_type, error_data["empty_messages"]) + +def test_stream_error_handling(): + """测试流式响应的错误处理 + + 测试场景: + 1. 空消息列表 + 2. 
消息格式错误(缺少必需字段) + + 错误响应会立即返回,不会建立流式连接。 + 状态码应该是 4xx,并返回详细的错误信息。 + """ + url = get_base_url() + + if OutputControl.is_verbose(): + print("\n=== 测试流式响应错误处理 ===") + + # 测试空消息列表 + if OutputControl.is_verbose(): + print("\n--- 测试空消息列表(流式)---") + data = create_error_test_data("empty_messages") + response = make_request(url, data, stream=True) + print(f"状态码: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "错误信息") + response.close() + + # 测试无效角色字段 + if OutputControl.is_verbose(): + print("\n--- 测试无效角色字段(流式)---") + data = create_error_test_data("invalid_role") + response = make_request(url, data, stream=True) + print(f"状态码: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "错误信息") + response.close() + + # 测试缺少内容字段 + if OutputControl.is_verbose(): + print("\n--- 测试缺少内容字段(流式)---") + data = create_error_test_data("missing_content") + response = make_request(url, data, stream=True) + print(f"状态码: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "错误信息") + response.close() + +def test_error_handling(): + """测试非流式响应的错误处理 + + 测试场景: + 1. 空消息列表 + 2. 消息格式错误(缺少必需字段) + + 错误响应格式: + { + "detail": "错误描述" + } + + 所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。 + """ + url = get_base_url() + + if OutputControl.is_verbose(): + print("\n=== 测试错误处理 ===") + + # 测试空消息列表 + if OutputControl.is_verbose(): + print("\n--- 测试空消息列表 ---") + data = create_error_test_data("empty_messages") + data["stream"] = False # 修改为非流式模式 + response = make_request(url, data) + print(f"状态码: {response.status_code}") + print_json_response(response.json(), "错误信息") + + # 测试无效角色字段 + if OutputControl.is_verbose(): + print("\n--- 测试无效角色字段 ---") + data = create_error_test_data("invalid_role") + data["stream"] = False # 修改为非流式模式 + response = make_request(url, data) + print(f"状态码: {response.status_code}") + print_json_response(response.json(), "错误信息") + + # 测试缺少内容字段 + if OutputControl.is_verbose(): + print("\n--- 测试缺少内容字段 ---") + data = create_error_test_data("missing_content") + data["stream"] = False # 修改为非流式模式 + response = make_request(url, data) + print(f"状态码: {response.status_code}") + print_json_response(response.json(), "错误信息") + +def get_test_cases() -> Dict[str, Callable]: + """获取所有可用的测试用例 + + Returns: + 测试名称到测试函数的映射字典 + """ + return { + "non_stream": test_non_stream_chat, + "stream": test_stream_chat, + "modes": test_query_modes, + "errors": test_error_handling, + "stream_errors": test_stream_error_handling + } + +def create_default_config(): + """创建默认配置文件""" + config_path = Path("config.json") + if not config_path.exists(): + with open(config_path, "w", encoding="utf-8") as f: + json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2) + print(f"已创建默认配置文件: {config_path}") + +def parse_args() -> argparse.Namespace: + """解析命令行参数""" + parser = argparse.ArgumentParser( + description="LightRAG Ollama 兼容接口测试", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +配置文件 (config.json): + { + "server": { + "host": "localhost", # 服务器地址 + "port": 9621, # 服务器端口 + "model": "lightrag:latest" # 默认模型名称 + }, + "test_cases": { + "basic": { + "query": "测试查询", # 基本查询文本 + "stream_query": "流式查询" # 流式查询文本 + } + } + } +""" + ) + parser.add_argument( + "--tests", + nargs="+", + choices=list(get_test_cases().keys()) + ["all"], + default=["all"], + help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试" + ) + parser.add_argument( + "--init-config", + action="store_true", + help="创建默认配置文件" + ) + parser.add_argument( + "--output", + 
type=str, + default="test_results.json", + help="测试结果输出文件路径" + ) + parser.add_argument( + "-q", "--quiet", + action="store_true", + help="静默模式,只显示测试结果摘要" + ) + return parser.parse_args() if __name__ == "__main__": - # 先测试非流式调用 - test_non_stream_chat() + args = parse_args() - # 再测试流式调用 - test_stream_chat() + # 设置输出模式 + OutputControl.set_verbose(not args.quiet) + + # 如果指定了创建配置文件 + if args.init_config: + create_default_config() + exit(0) + + test_cases = get_test_cases() + + try: + if "all" in args.tests: + # 运行所有测试 + if OutputControl.is_verbose(): + print("\n【基本功能测试】") + run_test(test_non_stream_chat, "非流式调用测试") + run_test(test_stream_chat, "流式调用测试") + + if OutputControl.is_verbose(): + print("\n【查询模式测试】") + run_test(test_query_modes, "查询模式测试") + + if OutputControl.is_verbose(): + print("\n【错误处理测试】") + run_test(test_error_handling, "错误处理测试") + run_test(test_stream_error_handling, "流式错误处理测试") + else: + # 运行指定的测试 + for test_name in args.tests: + if OutputControl.is_verbose(): + print(f"\n【运行测试: {test_name}】") + run_test(test_cases[test_name], test_name) + except Exception as e: + print(f"\n发生错误: {str(e)}") + finally: + # 打印并导出测试统计 + STATS.print_summary() + STATS.export_results(args.output) From e978a15593f3bbb45e9fd21135f52e179e3e6ca6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 22:39:41 +0800 Subject: [PATCH 24/42] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=9C=AA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E7=9A=84=E5=AF=BC=E5=85=A5=E5=B9=B6=E7=AE=80=E5=8C=96?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test_lightrag_ollama_chat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index ae32a3b8..b9902e77 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -12,9 +12,8 @@ LightRAG Ollama 兼容接口测试脚本 import requests import json import argparse -import os import time -from typing import Dict, Any, Optional, List, Callable, Tuple +from typing import Dict, Any, Optional, List, Callable from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path From 9632a8f0dc539f0b733d93a5c939e7a4a484d175 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 23:09:50 +0800 Subject: [PATCH 25/42] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E5=91=BD=E4=B8=AD=E7=BC=93=E5=AD=98=E6=97=B6=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E5=93=8D=E5=BA=94=E6=9C=AA=E9=81=B5=E5=BE=AAOllma=E8=A7=84?= =?UTF-8?q?=E8=8C=83=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - rag返回结果未字符串时,响应分两次发送 - 第一次发送查询内容 - 第二次发送统计信息 --- lightrag/api/lightrag_ollama.py | 21 ++++++++++++++++++--- test_lightrag_ollama_chat.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 714731eb..5768fc42 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -709,16 +709,31 @@ def create_app(args): try: # 确保 response 是异步生成器 if isinstance(response, str): - # 如果是字符串,作为单个完整响应发送 + # 如果是字符串,分两次发送 + # 第一次发送查询内容 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "message": { - "role": "assistant", + "role": "assistant", "content": response, "images": None }, - "done": True + "done": False + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + # 第二次发送统计信息 + data = { + "model": LIGHTRAG_MODEL, + "created_at": 
LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": 1, + "load_duration": 1, + "prompt_eval_count": 999, + "prompt_eval_duration": 1, + "eval_count": 999, + "eval_duration": 1 } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index b9902e77..02b51b22 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -119,7 +119,7 @@ DEFAULT_CONFIG = { "test_cases": { "basic": { "query": "孙悟空", - "stream_query": "孙悟空有什么法力,性格特征是什么" + "stream_query": "孙悟空" } } } From ea22d62c25cd086c83b16b9b5343e3cc260defec Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 15 Jan 2025 23:11:15 +0800 Subject: [PATCH 26/42] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E8=B0=83=E8=AF=95?= =?UTF-8?q?=E6=97=A5=E5=BF=97=E6=89=93=E5=8D=B0=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 5768fc42..22f2e017 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -674,9 +674,9 @@ def create_app(args): @app.post("/api/chat") async def chat(raw_request: Request, request: OllamaChatRequest): - # 打印原始请求数据 - body = await raw_request.body() - logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}") + # # 打印原始请求数据 + # body = await raw_request.body() + # logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}") """Handle chat completion requests""" try: # 获取所有消息内容 From 350e080ec139e881c4359b53247097bbaf4a7edf Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 01:11:59 +0800 Subject: [PATCH 27/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=E9=85=8D=E7=BD=AE=E5=92=8C=E5=91=BD=E4=BB=A4?= =?UTF-8?q?=E8=A1=8C=E5=8F=82=E6=95=B0=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除冗余的stream_query配置 - 统一使用query作为测试查询内容 - 新增--ask参数覆盖查询内容 - 调整命令行参数顺序 --- test_lightrag_ollama_chat.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 02b51b22..e0ed5342 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -118,8 +118,7 @@ DEFAULT_CONFIG = { }, "test_cases": { "basic": { - "query": "孙悟空", - "stream_query": "孙悟空" + "query": "唐僧有几个徒弟" } } } @@ -292,7 +291,7 @@ def test_stream_chat(): """ url = get_base_url() data = create_request_data( - CONFIG["test_cases"]["basic"]["stream_query"], + CONFIG["test_cases"]["basic"]["query"], stream=True ) @@ -344,7 +343,7 @@ def test_query_modes(): if OutputControl.is_verbose(): print(f"\n=== 测试 /{mode} 模式 ===") data = create_request_data( - f"/{mode} 孙悟空的特点", + f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False ) @@ -534,11 +533,14 @@ def parse_args() -> argparse.Namespace: """ ) parser.add_argument( - "--tests", - nargs="+", - choices=list(get_test_cases().keys()) + ["all"], - default=["all"], - help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试" + "-q", "--quiet", + action="store_true", + help="静默模式,只显示测试结果摘要" + ) + parser.add_argument( + "-a", "--ask", + type=str, + help="指定查询内容,会覆盖配置文件中的查询设置" ) parser.add_argument( "--init-config", @@ -552,9 +554,11 @@ def parse_args() -> argparse.Namespace: help="测试结果输出文件路径" ) parser.add_argument( - "-q", "--quiet", - action="store_true", - help="静默模式,只显示测试结果摘要" + 
"--tests", + nargs="+", + choices=list(get_test_cases().keys()) + ["all"], + default=["all"], + help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试" ) return parser.parse_args() @@ -564,6 +568,10 @@ if __name__ == "__main__": # 设置输出模式 OutputControl.set_verbose(not args.quiet) + # 如果指定了查询内容,更新配置 + if args.ask: + CONFIG["test_cases"]["basic"]["query"] = args.ask + # 如果指定了创建配置文件 if args.init_config: create_default_config() From 7658f4cbf01aceca6d623ec9b71ff0c5f8223775 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 01:16:53 +0800 Subject: [PATCH 28/42] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=80=A7=E8=83=BD?= =?UTF-8?q?=E7=BB=9F=E8=AE=A1=E7=9A=84=E6=89=93=E5=8D=B0=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 注释掉非流式聊天中的性能统计打印 - 注释掉流式聊天中的性能统计打印 - 保持代码简洁,减少冗余输出 --- test_lightrag_ollama_chat.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index e0ed5342..069859f0 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -261,15 +261,15 @@ def test_non_stream_chat(): "message": response_json["message"] }, "响应内容") - # 打印性能统计 - print_json_response({ - "total_duration": response_json["total_duration"], - "load_duration": response_json["load_duration"], - "prompt_eval_count": response_json["prompt_eval_count"], - "prompt_eval_duration": response_json["prompt_eval_duration"], - "eval_count": response_json["eval_count"], - "eval_duration": response_json["eval_duration"] - }, "性能统计") + # # 打印性能统计 + # print_json_response({ + # "total_duration": response_json["total_duration"], + # "load_duration": response_json["load_duration"], + # "prompt_eval_count": response_json["prompt_eval_count"], + # "prompt_eval_duration": response_json["prompt_eval_duration"], + # "eval_count": response_json["eval_count"], + # "eval_duration": response_json["eval_duration"] + # }, "性能统计") def test_stream_chat(): """测试流式调用 /api/chat 接口 @@ -309,7 +309,7 @@ def test_stream_chat(): data = json.loads(line.decode('utf-8')) if data.get("done", True): # 如果是完成标记 if "total_duration" in data: # 最终的性能统计消息 - print_json_response(data, "性能统计") + # print_json_response(data, "性能统计") break else: # 正常的内容消息 message = data.get("message", {}) From 5e4c9dd4d7c6b617341fdc92da1eb3e44efdb751 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 03:26:47 +0800 Subject: [PATCH 29/42] =?UTF-8?q?=E7=A7=BB=E9=99=A4api=20server=20?= =?UTF-8?q?=E5=AF=B9=20lightrag-hku=20=E7=9A=84=E4=BE=9D=E8=B5=96=EF=BC=88?= =?UTF-8?q?=E8=A7=A3=E5=86=B3=E9=9D=9E=E7=BC=96=E8=BE=91=E8=B0=83=E8=AF=95?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E5=AE=89=E8=A3=85=E6=97=A0=E6=B3=95=E5=90=AF?= =?UTF-8?q?=E5=8A=A8api=E6=9C=8D=E5=8A=A1=E7=9A=84=E9=97=AE=E9=A2=98?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index 9154809c..74776828 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -1,7 +1,6 @@ aioboto3 ascii_colors fastapi -lightrag-hku nano_vectordb nest_asyncio numpy From 9c69438c3e5584e44f11fb231adea5ba9d03d165 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 17:50:28 +0800 Subject: [PATCH 30/42] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E5=90=AF=E5=8A=A8=E8=84=9A=E6=9C=AC=EF=BC=9A=E6=BF=80=E6=B4=BB?= 
=?UTF-8?q?=E8=99=9A=E6=8B=9F=E7=8E=AF=E5=A2=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- start-server.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/start-server.sh b/start-server.sh index a250a3c3..4e143f37 100755 --- a/start-server.sh +++ b/start-server.sh @@ -1 +1,3 @@ +. venv/bin/activate + lightrag-ollama --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024 From 95ff048a9ec2fbbdc0d529c5ef5d9400389b0bf6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 19:42:34 +0800 Subject: [PATCH 31/42] =?UTF-8?q?=E4=B8=BAOllama=20API=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=80=A7=E8=83=BD=E7=BB=9F=E8=AE=A1=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增token估算函数 - 记录流式响应时间 - 计算输入输出token数 - 统计响应生成时间 - 返回详细的性能指标 --- lightrag/api/lightrag_ollama.py | 95 ++++++++++++++++++++++++++------- 1 file changed, 77 insertions(+), 18 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 22f2e017..1c117e86 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -3,6 +3,8 @@ from pydantic import BaseModel import logging import argparse import json +import time +import re from typing import List, Dict, Any, Optional from lightrag import LightRAG, QueryParam from lightrag.llm import openai_complete_if_cache, ollama_embedding @@ -24,6 +26,20 @@ from starlette.status import HTTP_403_FORBIDDEN from dotenv import load_dotenv load_dotenv() +def estimate_tokens(text: str) -> int: + """估算文本的token数量 + 中文每字约1.5个token + 英文每字约0.25个token + """ + # 使用正则表达式分别匹配中文字符和非中文字符 + chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) + non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text)) + + # 计算估算的token数量 + tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 + + return int(tokens) + # Constants for model information LIGHTRAG_NAME = "lightrag" LIGHTRAG_TAG = "latest" @@ -690,6 +706,12 @@ def create_app(args): # 解析查询模式 cleaned_query, mode = parse_query_mode(query) + # 开始计时 + start_time = time.time_ns() + + # 计算输入token数量 + prompt_tokens = estimate_tokens(cleaned_query) + # 调用RAG进行查询 query_param = QueryParam( mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid @@ -707,9 +729,17 @@ def create_app(args): async def stream_generator(): try: + first_chunk_time = None + last_chunk_time = None + total_response = "" + # 确保 response 是异步生成器 if isinstance(response, str): # 如果是字符串,分两次发送 + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + # 第一次发送查询内容 data = { "model": LIGHTRAG_MODEL, @@ -723,23 +753,38 @@ def create_app(args): } yield f"{json.dumps(data, ensure_ascii=False)}\n" + # 计算各项指标 + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time # 总时间 + prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 + eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 + # 第二次发送统计信息 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "done": True, - "total_duration": 1, - "load_duration": 1, - "prompt_eval_count": 999, - "prompt_eval_duration": 1, - "eval_count": 999, - "eval_duration": 1 + "total_duration": total_time, # 总时间 + "load_duration": 0, # 加载时间为0 + "prompt_eval_count": prompt_tokens, # 输入token数 + "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 + "eval_count": completion_tokens, # 输出token数 + "eval_duration": eval_time # 生成响应的时间 } yield f"{json.dumps(data, 
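To make the token heuristic above concrete, here is a small self-contained sketch that mirrors the estimate_tokens function added in this patch (about 1.5 tokens per CJK character, 0.25 per other character); the sample strings and printed values are purely illustrative:

    import re

    def estimate_tokens(text: str) -> int:
        # Same heuristic as the patch: CJK characters count as ~1.5 tokens,
        # everything else as ~0.25 tokens, truncated to an integer.
        chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
        non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
        return int(chinese_chars * 1.5 + non_chinese_chars * 0.25)

    print(estimate_tokens("唐僧有几个徒弟"))  # 7 CJK chars -> int(10.5) = 10
    print(estimate_tokens("Hello world"))    # 11 other chars -> int(2.75) = 2

Because this estimate feeds the prompt_eval_count and eval_count fields in the statistics records above, those numbers are approximations rather than real tokenizer counts.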
ensure_ascii=False)}\n" else: # 流式响应 async for chunk in response: if chunk: # 只发送非空内容 + # 记录第一个chunk的时间 + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + # 更新最后一个chunk的时间 + last_chunk_time = time.time_ns() + + # 累积响应内容 + total_response += chunk data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, @@ -752,17 +797,23 @@ def create_app(args): } yield f"{json.dumps(data, ensure_ascii=False)}\n" + # 计算各项指标 + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time # 总时间 + prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 + eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 + # 发送完成标记,包含性能统计信息 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "done": True, - "total_duration": 1, # 由于我们没有实际统计这些指标,暂时使用默认值 - "load_duration": 1, - "prompt_eval_count": 999, - "prompt_eval_duration": 1, - "eval_count": 999, - "eval_duration": 1 + "total_duration": total_time, # 总时间 + "load_duration": 0, # 加载时间为0 + "prompt_eval_count": prompt_tokens, # 输入token数 + "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 + "eval_count": completion_tokens, # 输出token数 + "eval_duration": eval_time # 生成响应的时间 } yield f"{json.dumps(data, ensure_ascii=False)}\n" return # 确保生成器在发送完成标记后立即结束 @@ -784,14 +835,22 @@ def create_app(args): ) else: # 非流式响应 + first_chunk_time = time.time_ns() response_text = await rag.aquery( cleaned_query, param=query_param ) + last_chunk_time = time.time_ns() # 确保响应不为空 if not response_text: response_text = "No response generated" + + # 计算各项指标 + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time # 总时间 + prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 + eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 # 构造响应,包含性能统计信息 return { @@ -803,12 +862,12 @@ def create_app(args): "images": None }, "done": True, - "total_duration": 1, # 由于我们没有实际统计这些指标,暂时使用默认值 - "load_duration": 1, - "prompt_eval_count": 999, - "prompt_eval_duration": 1, - "eval_count": 999, - "eval_duration": 1 + "total_duration": total_time, # 总时间 + "load_duration": 0, # 加载时间为0 + "prompt_eval_count": prompt_tokens, # 输入token数 + "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 + "eval_count": completion_tokens, # 输出token数 + "eval_duration": eval_time # 生成响应的时间 } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) From 68de8f6d04e539fb5f5a6f26a465467774d117f4 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 20:14:43 +0800 Subject: [PATCH 32/42] delete private test file --- examples/lightrag_yangdx.py | 82 ------------------------------------- 1 file changed, 82 deletions(-) delete mode 100644 examples/lightrag_yangdx.py diff --git a/examples/lightrag_yangdx.py b/examples/lightrag_yangdx.py deleted file mode 100644 index 2691deac..00000000 --- a/examples/lightrag_yangdx.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -# import asyncio -# import inspect -import logging -from dotenv import load_dotenv -from lightrag import LightRAG, QueryParam -from lightrag.llm import openai_complete_if_cache, ollama_embedding -from lightrag.utils import EmbeddingFunc - -load_dotenv() - -WORKING_DIR = "./examples/output" - -logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - "deepseek-chat", - prompt, - system_prompt=system_prompt, - history_messages=history_messages, 
- api_key=os.getenv("DEEPSEEK_API_KEY"), - base_url=os.getenv("DEEPSEEK_ENDPOINT"), - **kwargs, - ) - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - -rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=1024, - max_token_size=8192, - func=lambda texts: ollama_embedding( - texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434" - ), - ), -) - -with open("./examples/input/book.txt", "r", encoding="utf-8") as f: - rag.insert(f.read()) - -# Perform naive search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) -) - -# Perform local search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) -) - -# Perform global search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) -) - -# Perform hybrid search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) -) - -# # stream response -# resp = rag.query( -# "What are the top themes in this story?", -# param=QueryParam(mode="hybrid", stream=True), -# ) - - -# async def print_stream(stream): -# async for chunk in stream: -# print(chunk, end="", flush=True) - - -# if inspect.isasyncgen(resp): -# asyncio.run(print_stream(resp)) -# else: -# print(resp) From b38a98a51496d60e474fd78560eb150747bc3558 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 20:22:53 +0800 Subject: [PATCH 33/42] =?UTF-8?q?=E9=BB=98=E8=AE=A4=E4=B8=8D=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E6=B5=8B=E8=AF=95=E7=BB=93=E6=9E=9C=E5=88=B0=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test_lightrag_ollama_chat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 069859f0..5f8a03da 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -550,8 +550,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--output", type=str, - default="test_results.json", - help="测试结果输出文件路径" + default="", + help="测试结果输出文件路径,默认不输出到文件" ) parser.add_argument( "--tests", @@ -604,6 +604,8 @@ if __name__ == "__main__": except Exception as e: print(f"\n发生错误: {str(e)}") finally: - # 打印并导出测试统计 + # 打印测试统计 STATS.print_summary() - STATS.export_results(args.output) + # 如果指定了输出文件路径,则导出结果 + if args.output: + STATS.export_results(args.output) From ac11a7192e998ae95d008195019ba5dcb3b144b4 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 16 Jan 2025 21:04:45 +0800 Subject: [PATCH 34/42] revert changeds make by mistake --- lightrag/api/lightrag_server.py | 69 ++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 268efc1d..0d154b38 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -3,10 +3,10 @@ from pydantic import BaseModel import logging import argparse from lightrag import LightRAG, QueryParam -# from lightrag.llm import lollms_model_complete, lollms_embed -# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding -from lightrag.llm import openai_complete_if_cache, ollama_embedding -# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding +from lightrag.llm import lollms_model_complete, lollms_embed +from lightrag.llm import ollama_model_complete, 
ollama_embed +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding from lightrag.utils import EmbeddingFunc from typing import Optional, List, Union @@ -24,28 +24,13 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN import pipmaster as pm -from dotenv import load_dotenv -load_dotenv() - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - "deepseek-chat", - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=os.getenv("DEEPSEEK_API_KEY"), - base_url=os.getenv("DEEPSEEK_ENDPOINT"), - **kwargs, - ) def get_default_host(binding_type: str) -> str: default_hosts = { - "ollama": "http://m4.lan.znipower.com:11434", + "ollama": "http://localhost:11434", "lollms": "http://localhost:9600", "azure_openai": "https://api.openai.com/v1", - "openai": os.getenv("DEEPSEEK_ENDPOINT"), + "openai": "https://api.openai.com/v1", } return default_hosts.get( binding_type, "http://localhost:11434" @@ -334,12 +319,44 @@ def create_app(args): # Initialize RAG rag = LightRAG( working_dir=args.working_dir, - llm_model_func=llm_model_func, + llm_model_func=lollms_model_complete + if args.llm_binding == "lollms" + else ollama_model_complete + if args.llm_binding == "ollama" + else azure_openai_complete_if_cache + if args.llm_binding == "azure_openai" + else openai_complete_if_cache, + llm_model_name=args.llm_model, + llm_model_max_async=args.max_async, + llm_model_max_token_size=args.max_tokens, + llm_model_kwargs={ + "host": args.llm_binding_host, + "timeout": args.timeout, + "options": {"num_ctx": args.max_tokens}, + }, embedding_func=EmbeddingFunc( - embedding_dim=1024, - max_token_size=8192, - func=lambda texts: ollama_embedding( - texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434" + embedding_dim=args.embedding_dim, + max_token_size=args.max_embed_tokens, + func=lambda texts: lollms_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + ) + if args.llm_binding == "lollms" + else ollama_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + ) + if args.llm_binding == "ollama" + else azure_openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai + ) + if args.llm_binding == "azure_openai" + else openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai ), ), ) From 34d6b85adbf25cbbc8e3e969bbe2ac954fc97027 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 01:50:07 +0800 Subject: [PATCH 35/42] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=B8=85=E7=90=86?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E5=89=8D=E7=BC=80=E6=97=B6=E6=9C=AA=E8=83=BD?= =?UTF-8?q?=E6=AD=A3=E7=A1=AE=E6=B8=85=E7=90=86=E7=A9=BA=E6=A0=BC=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 1c117e86..bd068653 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -684,7 +684,9 @@ def create_app(args): for prefix, mode in mode_map.items(): if query.startswith(prefix): - return query[len(prefix):], mode + # 移除前缀后,清理开头的额外空格 + cleaned_query = 
query[len(prefix):].lstrip() + return cleaned_query, mode return query, SearchMode.hybrid From 847963d19a99213a826d60e8251004bcc20ff0a8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 03:35:03 +0800 Subject: [PATCH 36/42] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20/query=20=E5=92=8C?= =?UTF-8?q?=20/query/stream=20=E7=AB=AF=E7=82=B9=E5=A4=84=E7=90=86stream?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E6=98=AF=E7=9A=84=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 51 ++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index bd068653..af3f22ee 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -489,27 +489,21 @@ def create_app(args): ), ) - if request.stream: - from fastapi.responses import StreamingResponse - - async def stream_generator(): - async for chunk in response: - yield f"{json.dumps({'response': chunk})}\n" - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type" - } - ) - else: + # 如果响应是字符串(比如命中缓存),直接返回 + if isinstance(response, str): return QueryResponse(response=response) + + # 如果是异步生成器,根据stream参数决定是否流式返回 + if request.stream: + result = "" + async for chunk in response: + result += chunk + return QueryResponse(response=result) + else: + result = "" + async for chunk in response: + result += chunk + return QueryResponse(response=result) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -528,8 +522,18 @@ def create_app(args): from fastapi.responses import StreamingResponse async def stream_generator(): - async for chunk in response: - yield f"{json.dumps({'response': chunk})}\n" + if isinstance(response, str): + # 如果是字符串,一次性发送 + yield f"{json.dumps({'response': response})}\n" + else: + # 如果是异步生成器,逐块发送 + try: + async for chunk in response: + if chunk: # 只发送非空内容 + yield f"{json.dumps({'response': chunk})}\n" + except Exception as e: + logging.error(f"Streaming error: {str(e)}") + yield f"{json.dumps({'error': str(e)})}\n" return StreamingResponse( stream_generator(), @@ -540,7 +544,8 @@ def create_app(args): "Content-Type": "application/x-ndjson", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type" + "Access-Control-Allow-Headers": "Content-Type", + "X-Accel-Buffering": "no" # 禁用 Nginx 缓冲 } ) except Exception as e: From 3138ae7599adfc490ea4e4fbfe59612a5d5adeb1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 11:04:36 +0800 Subject: [PATCH 37/42] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AF=B9=20mix=20?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E6=A8=A1=E5=BC=8F=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/api/lightrag_ollama.py | 4 +++- test_lightrag_ollama_chat.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index af3f22ee..82329806 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -243,6 +243,7 @@ class SearchMode(str, Enum): local = "local" global_ = "global" # 
使用 global_ 因为 global 是 Python 保留关键字,但枚举值会转换为字符串 "global" hybrid = "hybrid" + mix = "mix" # Ollama API compatible models class OllamaMessage(BaseModel): @@ -684,7 +685,8 @@ def create_app(args): "/local ": SearchMode.local, "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword "/naive ": SearchMode.naive, - "/hybrid ": SearchMode.hybrid + "/hybrid ": SearchMode.hybrid, + "/mix ": SearchMode.mix } for prefix, mode in mode_map.items(): diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 5f8a03da..aab059aa 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -337,7 +337,7 @@ def test_query_modes(): 每个模式都会返回相同格式的响应,但检索策略不同。 """ url = get_base_url() - modes = ["local", "global", "naive", "hybrid"] # 支持的查询模式 + modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式 for mode in modes: if OutputControl.is_verbose(): From 939e399dd4c8f8ceedebc9b5966f2390e20e3921 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 13:36:31 +0800 Subject: [PATCH 38/42] Translate comment to English --- lightrag/api/lightrag_ollama.py | 37 ++++----- test_lightrag_ollama_chat.py | 133 ++++++++++++++------------------ 2 files changed, 77 insertions(+), 93 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 82329806..bb9b1ac5 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -27,15 +27,15 @@ from dotenv import load_dotenv load_dotenv() def estimate_tokens(text: str) -> int: - """估算文本的token数量 - 中文每字约1.5个token - 英文每字约0.25个token + """Estimate the number of tokens in text + Chinese characters: approximately 1.5 tokens per character + English characters: approximately 0.25 tokens per character """ - # 使用正则表达式分别匹配中文字符和非中文字符 + # Use regex to match Chinese and non-Chinese characters separately chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text)) - # 计算估算的token数量 + # Calculate estimated token count tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 return int(tokens) @@ -241,7 +241,7 @@ class DocumentManager: class SearchMode(str, Enum): naive = "naive" local = "local" - global_ = "global" # 使用 global_ 因为 global 是 Python 保留关键字,但枚举值会转换为字符串 "global" + global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global" hybrid = "hybrid" mix = "mix" @@ -254,7 +254,7 @@ class OllamaMessage(BaseModel): class OllamaChatRequest(BaseModel): model: str = LIGHTRAG_MODEL messages: List[OllamaMessage] - stream: bool = True # 默认为流式模式 + stream: bool = True # Default to streaming mode options: Optional[Dict[str, Any]] = None class OllamaChatResponse(BaseModel): @@ -490,11 +490,11 @@ def create_app(args): ), ) - # 如果响应是字符串(比如命中缓存),直接返回 + # If response is a string (e.g. 
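As a quick illustration of the prefix handling that the two fixes above converge on (prefix removed, leading spaces stripped, unrecognised input falling back to hybrid), here is a trimmed-down sketch; the SearchMode enum is reduced to its values and the sample queries are illustrative only:

    from enum import Enum

    class SearchMode(str, Enum):
        naive = "naive"
        local = "local"
        global_ = "global"
        hybrid = "hybrid"
        mix = "mix"

    def parse_query_mode(query: str):
        mode_map = {
            "/local ": SearchMode.local,
            "/global ": SearchMode.global_,
            "/naive ": SearchMode.naive,
            "/hybrid ": SearchMode.hybrid,
            "/mix ": SearchMode.mix,
        }
        for prefix, mode in mode_map.items():
            if query.startswith(prefix):
                # Strip the prefix, then any extra leading spaces
                return query[len(prefix):].lstrip(), mode
        return query, SearchMode.hybrid

    print(parse_query_mode("/mix  唐僧有几个徒弟"))  # ('唐僧有几个徒弟', <SearchMode.mix: 'mix'>)
    print(parse_query_mode("孙悟空"))               # ('孙悟空', <SearchMode.hybrid: 'hybrid'>)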
cache hit), return directly if isinstance(response, str): return QueryResponse(response=response) - # 如果是异步生成器,根据stream参数决定是否流式返回 + # If it's an async generator, decide whether to stream based on stream parameter if request.stream: result = "" async for chunk in response: @@ -511,7 +511,7 @@ def create_app(args): @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) async def query_text_stream(request: QueryRequest): try: - response = await rag.aquery( # 使用 aquery 而不是 query,并添加 await + response = await rag.aquery( # Use aquery instead of query, and add await request.query, param=QueryParam( mode=request.mode, @@ -691,7 +691,7 @@ def create_app(args): for prefix, mode in mode_map.items(): if query.startswith(prefix): - # 移除前缀后,清理开头的额外空格 + # After removing prefix an leading spaces cleaned_query = query[len(prefix):].lstrip() return cleaned_query, mode @@ -699,17 +699,14 @@ def create_app(args): @app.post("/api/chat") async def chat(raw_request: Request, request: OllamaChatRequest): - # # 打印原始请求数据 - # body = await raw_request.body() - # logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}") """Handle chat completion requests""" try: - # 获取所有消息内容 + # Get all messages messages = request.messages if not messages: raise HTTPException(status_code=400, detail="No messages provided") - # 获取最后一条消息作为查询 + # Get the last message as query query = messages[-1].content # 解析查询模式 @@ -723,7 +720,7 @@ def create_app(args): # 调用RAG进行查询 query_param = QueryParam( - mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid + mode=mode, stream=request.stream, only_need_context=False ) @@ -731,7 +728,7 @@ def create_app(args): if request.stream: from fastapi.responses import StreamingResponse - response = await rag.aquery( # 需要 await 来获取异步生成器 + response = await rag.aquery( # Need await to get async generator cleaned_query, param=query_param ) @@ -742,9 +739,9 @@ def create_app(args): last_chunk_time = None total_response = "" - # 确保 response 是异步生成器 + # Ensure response is an async generator if isinstance(response, str): - # 如果是字符串,分两次发送 + # If it's a string, send in two parts first_chunk_time = time.time_ns() last_chunk_time = first_chunk_time total_response = response diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index aab059aa..44b4fa42 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -1,12 +1,12 @@ """ -LightRAG Ollama 兼容接口测试脚本 +LightRAG Ollama Compatibility Interface Test Script -这个脚本测试 LightRAG 的 Ollama 兼容接口,包括: -1. 基本功能测试(流式和非流式响应) -2. 查询模式测试(local、global、naive、hybrid) -3. 错误处理测试(包括流式和非流式场景) +This script tests the LightRAG's Ollama compatibility interface, including: +1. Basic functionality tests (streaming and non-streaming responses) +2. Query mode tests (local, global, naive, hybrid) +3. Error handling tests (including streaming and non-streaming scenarios) -所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。 +All responses use the JSON Lines format, complying with the Ollama API specification. 
""" import requests @@ -24,20 +24,10 @@ class OutputControl: @classmethod def set_verbose(cls, verbose: bool) -> None: - """设置输出详细程度 - - Args: - verbose: True 为详细模式,False 为静默模式 - """ cls._verbose = verbose @classmethod def is_verbose(cls) -> bool: - """获取当前输出模式 - - Returns: - 当前是否为详细模式 - """ return cls._verbose @dataclass @@ -48,9 +38,8 @@ class TestResult: duration: float error: Optional[str] = None timestamp: str = "" - + def __post_init__(self): - """初始化后设置时间戳""" if not self.timestamp: self.timestamp = datetime.now().isoformat() @@ -59,14 +48,13 @@ class TestStats: def __init__(self): self.results: List[TestResult] = [] self.start_time = datetime.now() - + def add_result(self, result: TestResult): - """添加测试结果""" self.results.append(result) - + def export_results(self, path: str = "test_results.json"): """导出测试结果到 JSON 文件 - + Args: path: 输出文件路径 """ @@ -81,25 +69,24 @@ class TestStats: "total_duration": sum(r.duration for r in self.results) } } - + with open(path, "w", encoding="utf-8") as f: json.dump(results_data, f, ensure_ascii=False, indent=2) print(f"\n测试结果已保存到: {path}") - + def print_summary(self): - """打印测试统计摘要""" total = len(self.results) passed = sum(1 for r in self.results if r.success) failed = total - passed duration = sum(r.duration for r in self.results) - + print("\n=== 测试结果摘要 ===") print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") print(f"总用时: {duration:.2f}秒") print(f"总计: {total} 个测试") print(f"通过: {passed} 个") print(f"失败: {failed} 个") - + if failed > 0: print("\n失败的测试:") for result in self.results: @@ -125,15 +112,15 @@ DEFAULT_CONFIG = { def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response: """发送 HTTP 请求,支持重试机制 - + Args: url: 请求 URL data: 请求数据 stream: 是否使用流式响应 - + Returns: - requests.Response 对象 - + requests.Response: 对象 + Raises: requests.exceptions.RequestException: 请求失败且重试次数用完 """ @@ -141,7 +128,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques max_retries = server_config["max_retries"] retry_delay = server_config["retry_delay"] timeout = server_config["timeout"] - + for attempt in range(max_retries): try: response = requests.post( @@ -159,10 +146,10 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques def load_config() -> Dict[str, Any]: """加载配置文件 - + 首先尝试从当前目录的 config.json 加载, 如果不存在则使用默认配置 - + Returns: 配置字典 """ @@ -174,7 +161,7 @@ def load_config() -> Dict[str, Any]: def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: """格式化打印 JSON 响应数据 - + Args: data: 要打印的数据字典 title: 打印的标题 @@ -199,12 +186,12 @@ def create_request_data( model: str = None ) -> Dict[str, Any]: """创建基本的请求数据 - + Args: content: 用户消息内容 stream: 是否使用流式响应 model: 模型名称 - + Returns: 包含完整请求数据的字典 """ @@ -224,7 +211,7 @@ STATS = TestStats() def run_test(func: Callable, name: str) -> None: """运行测试并记录结果 - + Args: func: 测试函数 name: 测试名称 @@ -246,21 +233,21 @@ def test_non_stream_chat(): CONFIG["test_cases"]["basic"]["query"], stream=False ) - + # 发送请求 response = make_request(url, data) - + # 打印响应 if OutputControl.is_verbose(): print("\n=== 非流式调用响应 ===") response_json = response.json() - + # 打印响应内容 print_json_response({ "model": response_json["model"], "message": response_json["message"] }, "响应内容") - + # # 打印性能统计 # print_json_response({ # "total_duration": response_json["total_duration"], @@ -273,7 +260,7 @@ def test_non_stream_chat(): def test_stream_chat(): """测试流式调用 /api/chat 接口 - + 使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。 响应格式: { @@ -286,7 +273,7 @@ def 
test_stream_chat(): }, "done": false } - + 最后一条消息会包含性能统计信息,done 为 true。 """ url = get_base_url() @@ -294,10 +281,10 @@ def test_stream_chat(): CONFIG["test_cases"]["basic"]["query"], stream=True ) - + # 发送请求并获取流式响应 response = make_request(url, data, stream=True) - + if OutputControl.is_verbose(): print("\n=== 流式调用响应 ===") output_buffer = [] @@ -321,24 +308,24 @@ def test_stream_chat(): print("Error decoding JSON from response line") finally: response.close() # 确保关闭响应连接 - + # 打印一个换行 print() def test_query_modes(): """测试不同的查询模式前缀 - + 支持的查询模式: - /local: 本地检索模式,只在相关度高的文档中搜索 - /global: 全局检索模式,在所有文档中搜索 - /naive: 朴素模式,不使用任何优化策略 - /hybrid: 混合模式(默认),结合多种策略 - + 每个模式都会返回相同格式的响应,但检索策略不同。 """ url = get_base_url() modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式 - + for mode in modes: if OutputControl.is_verbose(): print(f"\n=== 测试 /{mode} 模式 ===") @@ -346,11 +333,11 @@ def test_query_modes(): f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False ) - + # 发送请求 response = make_request(url, data) response_json = response.json() - + # 打印响应内容 print_json_response({ "model": response_json["model"], @@ -359,13 +346,13 @@ def test_query_modes(): def create_error_test_data(error_type: str) -> Dict[str, Any]: """创建用于错误测试的请求数据 - + Args: error_type: 错误类型,支持: - empty_messages: 空消息列表 - invalid_role: 无效的角色字段 - missing_content: 缺少内容字段 - + Returns: 包含错误数据的请求字典 """ @@ -399,19 +386,19 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: def test_stream_error_handling(): """测试流式响应的错误处理 - + 测试场景: 1. 空消息列表 2. 消息格式错误(缺少必需字段) - + 错误响应会立即返回,不会建立流式连接。 状态码应该是 4xx,并返回详细的错误信息。 """ url = get_base_url() - + if OutputControl.is_verbose(): print("\n=== 测试流式响应错误处理 ===") - + # 测试空消息列表 if OutputControl.is_verbose(): print("\n--- 测试空消息列表(流式)---") @@ -421,7 +408,7 @@ def test_stream_error_handling(): if response.status_code != 200: print_json_response(response.json(), "错误信息") response.close() - + # 测试无效角色字段 if OutputControl.is_verbose(): print("\n--- 测试无效角色字段(流式)---") @@ -444,23 +431,23 @@ def test_stream_error_handling(): def test_error_handling(): """测试非流式响应的错误处理 - + 测试场景: 1. 空消息列表 2. 
消息格式错误(缺少必需字段) - + 错误响应格式: { "detail": "错误描述" } - + 所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。 """ url = get_base_url() - + if OutputControl.is_verbose(): print("\n=== 测试错误处理 ===") - + # 测试空消息列表 if OutputControl.is_verbose(): print("\n--- 测试空消息列表 ---") @@ -469,7 +456,7 @@ def test_error_handling(): response = make_request(url, data) print(f"状态码: {response.status_code}") print_json_response(response.json(), "错误信息") - + # 测试无效角色字段 if OutputControl.is_verbose(): print("\n--- 测试无效角色字段 ---") @@ -490,7 +477,7 @@ def test_error_handling(): def get_test_cases() -> Dict[str, Callable]: """获取所有可用的测试用例 - + Returns: 测试名称到测试函数的映射字典 """ @@ -564,21 +551,21 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": args = parse_args() - + # 设置输出模式 OutputControl.set_verbose(not args.quiet) - + # 如果指定了查询内容,更新配置 if args.ask: CONFIG["test_cases"]["basic"]["query"] = args.ask - + # 如果指定了创建配置文件 if args.init_config: create_default_config() exit(0) - + test_cases = get_test_cases() - + try: if "all" in args.tests: # 运行所有测试 @@ -586,11 +573,11 @@ if __name__ == "__main__": print("\n【基本功能测试】") run_test(test_non_stream_chat, "非流式调用测试") run_test(test_stream_chat, "流式调用测试") - + if OutputControl.is_verbose(): print("\n【查询模式测试】") run_test(test_query_modes, "查询模式测试") - + if OutputControl.is_verbose(): print("\n【错误处理测试】") run_test(test_error_handling, "错误处理测试") From 48f70ff8b48f53570367bca74e5205a26184c014 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 14:07:17 +0800 Subject: [PATCH 39/42] Translate unit test comment and promts to English --- test_lightrag_ollama_chat.py | 343 ++++++++++++++++------------------- 1 file changed, 161 insertions(+), 182 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 44b4fa42..4f6cab29 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -19,7 +19,7 @@ from datetime import datetime from pathlib import Path class OutputControl: - """输出控制类,管理测试输出的详细程度""" + """Output control class, manages the verbosity of test output""" _verbose: bool = False @classmethod @@ -32,7 +32,7 @@ class OutputControl: @dataclass class TestResult: - """测试结果数据类""" + """Test result data class""" name: str success: bool duration: float @@ -44,7 +44,7 @@ class TestResult: self.timestamp = datetime.now().isoformat() class TestStats: - """测试统计信息""" + """Test statistics""" def __init__(self): self.results: List[TestResult] = [] self.start_time = datetime.now() @@ -53,10 +53,9 @@ class TestStats: self.results.append(result) def export_results(self, path: str = "test_results.json"): - """导出测试结果到 JSON 文件 - + """Export test results to a JSON file Args: - path: 输出文件路径 + path: Output file path """ results_data = { "start_time": self.start_time.isoformat(), @@ -72,7 +71,7 @@ class TestStats: with open(path, "w", encoding="utf-8") as f: json.dump(results_data, f, ensure_ascii=False, indent=2) - print(f"\n测试结果已保存到: {path}") + print(f"\nTest results saved to: {path}") def print_summary(self): total = len(self.results) @@ -80,28 +79,27 @@ class TestStats: failed = total - passed duration = sum(r.duration for r in self.results) - print("\n=== 测试结果摘要 ===") - print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") - print(f"总用时: {duration:.2f}秒") - print(f"总计: {total} 个测试") - print(f"通过: {passed} 个") - print(f"失败: {failed} 个") + print("\n=== Test Summary ===") + print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"Total duration: {duration:.2f} seconds") + print(f"Total tests: {total}") + print(f"Passed: {passed}") + 
print(f"Failed: {failed}") if failed > 0: - print("\n失败的测试:") + print("\nFailed tests:") for result in self.results: if not result.success: print(f"- {result.name}: {result.error}") -# 默认配置 DEFAULT_CONFIG = { "server": { "host": "localhost", "port": 9621, "model": "lightrag:latest", - "timeout": 30, # 请求超时时间(秒) - "max_retries": 3, # 最大重试次数 - "retry_delay": 1 # 重试间隔(秒) + "timeout": 30, + "max_retries": 3, + "retry_delay": 1 }, "test_cases": { "basic": { @@ -111,18 +109,16 @@ DEFAULT_CONFIG = { } def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response: - """发送 HTTP 请求,支持重试机制 - + """Send an HTTP request with retry mechanism Args: - url: 请求 URL - data: 请求数据 - stream: 是否使用流式响应 - + url: Request URL + data: Request data + stream: Whether to use streaming response Returns: - requests.Response: 对象 + requests.Response: Response object Raises: - requests.exceptions.RequestException: 请求失败且重试次数用完 + requests.exceptions.RequestException: Request failed after all retries """ server_config = CONFIG["server"] max_retries = server_config["max_retries"] @@ -139,19 +135,18 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques ) return response except requests.exceptions.RequestException as e: - if attempt == max_retries - 1: # 最后一次重试 + if attempt == max_retries - 1: # Last retry raise - print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}") + print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}") time.sleep(retry_delay) def load_config() -> Dict[str, Any]: - """加载配置文件 - - 首先尝试从当前目录的 config.json 加载, - 如果不存在则使用默认配置 + """Load configuration file + First try to load from config.json in the current directory, + if it doesn't exist, use the default configuration Returns: - 配置字典 + Configuration dictionary """ config_path = Path("config.json") if config_path.exists(): @@ -160,23 +155,22 @@ def load_config() -> Dict[str, Any]: return DEFAULT_CONFIG def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: - """格式化打印 JSON 响应数据 - + """Format and print JSON response data Args: - data: 要打印的数据字典 - title: 打印的标题 - indent: JSON 缩进空格数 + data: Data dictionary to print + title: Title to print + indent: Number of spaces for JSON indentation """ if OutputControl.is_verbose(): if title: print(f"\n=== {title} ===") print(json.dumps(data, ensure_ascii=False, indent=indent)) -# 全局配置 +# Global configuration CONFIG = load_config() def get_base_url() -> str: - """返回基础 URL""" + """Return the base URL""" server = CONFIG["server"] return f"http://{server['host']}:{server['port']}/api/chat" @@ -185,15 +179,13 @@ def create_request_data( stream: bool = False, model: str = None ) -> Dict[str, Any]: - """创建基本的请求数据 - + """Create basic request data Args: - content: 用户消息内容 - stream: 是否使用流式响应 - model: 模型名称 - + content: User message content + stream: Whether to use streaming response + model: Model name Returns: - 包含完整请求数据的字典 + Dictionary containing complete request data """ return { "model": model or CONFIG["server"]["model"], @@ -206,15 +198,14 @@ def create_request_data( "stream": stream } -# 全局测试统计 +# Global test statistics STATS = TestStats() def run_test(func: Callable, name: str) -> None: - """运行测试并记录结果 - + """Run a test and record the results Args: - func: 测试函数 - name: 测试名称 + func: Test function + name: Test name """ start_time = time.time() try: @@ -227,54 +218,43 @@ def run_test(func: Callable, name: str) -> None: raise def test_non_stream_chat(): - """测试非流式调用 /api/chat 接口""" + """Test non-streaming call to /api/chat endpoint""" url = 
get_base_url() data = create_request_data( CONFIG["test_cases"]["basic"]["query"], stream=False ) - # 发送请求 + # Send request response = make_request(url, data) - # 打印响应 + # Print response if OutputControl.is_verbose(): - print("\n=== 非流式调用响应 ===") + print("\n=== Non-streaming call response ===") response_json = response.json() - # 打印响应内容 + # Print response content print_json_response({ "model": response_json["model"], "message": response_json["message"] - }, "响应内容") - - # # 打印性能统计 - # print_json_response({ - # "total_duration": response_json["total_duration"], - # "load_duration": response_json["load_duration"], - # "prompt_eval_count": response_json["prompt_eval_count"], - # "prompt_eval_duration": response_json["prompt_eval_duration"], - # "eval_count": response_json["eval_count"], - # "eval_duration": response_json["eval_duration"] - # }, "性能统计") - + }, "Response content") def test_stream_chat(): - """测试流式调用 /api/chat 接口 + """Test streaming call to /api/chat endpoint - 使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。 - 响应格式: + Use JSON Lines format to process streaming responses, each line is a complete JSON object. + Response format: { "model": "lightrag:latest", "created_at": "2024-01-15T00:00:00Z", "message": { "role": "assistant", - "content": "部分响应内容", + "content": "Partial response content", "images": null }, "done": false } - 最后一条消息会包含性能统计信息,done 为 true。 + The last message will contain performance statistics, with done set to true. """ url = get_base_url() data = create_request_data( @@ -282,79 +262,79 @@ def test_stream_chat(): stream=True ) - # 发送请求并获取流式响应 + # Send request and get streaming response response = make_request(url, data, stream=True) if OutputControl.is_verbose(): - print("\n=== 流式调用响应 ===") + print("\n=== Streaming call response ===") output_buffer = [] try: for line in response.iter_lines(): - if line: # 跳过空行 + if line: # Skip empty lines try: - # 解码并解析 JSON + # Decode and parse JSON data = json.loads(line.decode('utf-8')) - if data.get("done", True): # 如果是完成标记 - if "total_duration" in data: # 最终的性能统计消息 - # print_json_response(data, "性能统计") + if data.get("done", True): # If it's the completion marker + if "total_duration" in data: # Final performance statistics message + # print_json_response(data, "Performance statistics") break - else: # 正常的内容消息 + else: # Normal content message message = data.get("message", {}) content = message.get("content", "") - if content: # 只收集非空内容 + if content: # Only collect non-empty content output_buffer.append(content) - print(content, end="", flush=True) # 实时打印内容 + print(content, end="", flush=True) # Print content in real-time except json.JSONDecodeError: print("Error decoding JSON from response line") finally: - response.close() # 确保关闭响应连接 + response.close() # Ensure the response connection is closed - # 打印一个换行 + # Print a newline print() def test_query_modes(): - """测试不同的查询模式前缀 + """Test different query mode prefixes - 支持的查询模式: - - /local: 本地检索模式,只在相关度高的文档中搜索 - - /global: 全局检索模式,在所有文档中搜索 - - /naive: 朴素模式,不使用任何优化策略 - - /hybrid: 混合模式(默认),结合多种策略 + Supported query modes: + - /local: Local retrieval mode, searches only in highly relevant documents + - /global: Global retrieval mode, searches across all documents + - /naive: Naive mode, does not use any optimization strategies + - /hybrid: Hybrid mode (default), combines multiple strategies + - /mix: Mix mode - 每个模式都会返回相同格式的响应,但检索策略不同。 + Each mode will return responses in the same format, but with different retrieval strategies. 
""" url = get_base_url() - modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式 + modes = ["local", "global", "naive", "hybrid", "mix"] for mode in modes: if OutputControl.is_verbose(): - print(f"\n=== 测试 /{mode} 模式 ===") + print(f"\n=== Testing /{mode} mode ===") data = create_request_data( f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False ) - # 发送请求 + # Send request response = make_request(url, data) response_json = response.json() - # 打印响应内容 + # Print response content print_json_response({ "model": response_json["model"], "message": response_json["message"] }) def create_error_test_data(error_type: str) -> Dict[str, Any]: - """创建用于错误测试的请求数据 - + """Create request data for error testing Args: - error_type: 错误类型,支持: - - empty_messages: 空消息列表 - - invalid_role: 无效的角色字段 - - missing_content: 缺少内容字段 + error_type: Error type, supported: + - empty_messages: Empty message list + - invalid_role: Invalid role field + - missing_content: Missing content field Returns: - 包含错误数据的请求字典 + Request dictionary containing error data """ error_data = { "empty_messages": { @@ -367,7 +347,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: "messages": [ { "invalid_role": "user", - "content": "测试消息" + "content": "Test message" } ], "stream": True @@ -385,101 +365,100 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: return error_data.get(error_type, error_data["empty_messages"]) def test_stream_error_handling(): - """测试流式响应的错误处理 + """Test error handling for streaming responses - 测试场景: - 1. 空消息列表 - 2. 消息格式错误(缺少必需字段) + Test scenarios: + 1. Empty message list + 2. Message format error (missing required fields) - 错误响应会立即返回,不会建立流式连接。 - 状态码应该是 4xx,并返回详细的错误信息。 + Error responses should be returned immediately without establishing a streaming connection. + The status code should be 4xx, and detailed error information should be returned. 
""" url = get_base_url() if OutputControl.is_verbose(): - print("\n=== 测试流式响应错误处理 ===") + print("\n=== Testing streaming response error handling ===") - # 测试空消息列表 + # Test empty message list if OutputControl.is_verbose(): - print("\n--- 测试空消息列表(流式)---") + print("\n--- Testing empty message list (streaming) ---") data = create_error_test_data("empty_messages") response = make_request(url, data, stream=True) - print(f"状态码: {response.status_code}") + print(f"Status code: {response.status_code}") if response.status_code != 200: - print_json_response(response.json(), "错误信息") + print_json_response(response.json(), "Error message") response.close() - # 测试无效角色字段 + # Test invalid role field if OutputControl.is_verbose(): - print("\n--- 测试无效角色字段(流式)---") + print("\n--- Testing invalid role field (streaming) ---") data = create_error_test_data("invalid_role") response = make_request(url, data, stream=True) - print(f"状态码: {response.status_code}") + print(f"Status code: {response.status_code}") if response.status_code != 200: - print_json_response(response.json(), "错误信息") + print_json_response(response.json(), "Error message") response.close() - # 测试缺少内容字段 + # Test missing content field if OutputControl.is_verbose(): - print("\n--- 测试缺少内容字段(流式)---") + print("\n--- Testing missing content field (streaming) ---") data = create_error_test_data("missing_content") response = make_request(url, data, stream=True) - print(f"状态码: {response.status_code}") + print(f"Status code: {response.status_code}") if response.status_code != 200: - print_json_response(response.json(), "错误信息") + print_json_response(response.json(), "Error message") response.close() def test_error_handling(): - """测试非流式响应的错误处理 + """Test error handling for non-streaming responses - 测试场景: - 1. 空消息列表 - 2. 消息格式错误(缺少必需字段) + Test scenarios: + 1. Empty message list + 2. Message format error (missing required fields) - 错误响应格式: + Error response format: { - "detail": "错误描述" + "detail": "Error description" } - 所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。 + All errors should return appropriate HTTP status codes and clear error messages. 
""" url = get_base_url() if OutputControl.is_verbose(): - print("\n=== 测试错误处理 ===") + print("\n=== Testing error handling ===") - # 测试空消息列表 + # Test empty message list if OutputControl.is_verbose(): - print("\n--- 测试空消息列表 ---") + print("\n--- Testing empty message list ---") data = create_error_test_data("empty_messages") - data["stream"] = False # 修改为非流式模式 + data["stream"] = False # Change to non-streaming mode response = make_request(url, data) - print(f"状态码: {response.status_code}") - print_json_response(response.json(), "错误信息") + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") - # 测试无效角色字段 + # Test invalid role field if OutputControl.is_verbose(): - print("\n--- 测试无效角色字段 ---") + print("\n--- Testing invalid role field ---") data = create_error_test_data("invalid_role") - data["stream"] = False # 修改为非流式模式 + data["stream"] = False # Change to non-streaming mode response = make_request(url, data) - print(f"状态码: {response.status_code}") - print_json_response(response.json(), "错误信息") + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") - # 测试缺少内容字段 + # Test missing content field if OutputControl.is_verbose(): - print("\n--- 测试缺少内容字段 ---") + print("\n--- Testing missing content field ---") data = create_error_test_data("missing_content") - data["stream"] = False # 修改为非流式模式 + data["stream"] = False # Change to non-streaming mode response = make_request(url, data) - print(f"状态码: {response.status_code}") - print_json_response(response.json(), "错误信息") + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") def get_test_cases() -> Dict[str, Callable]: - """获取所有可用的测试用例 - + """Get all available test cases Returns: - 测试名称到测试函数的映射字典 + A dictionary mapping test names to test functions """ return { "non_stream": test_non_stream_chat, @@ -490,30 +469,30 @@ def get_test_cases() -> Dict[str, Callable]: } def create_default_config(): - """创建默认配置文件""" + """Create a default configuration file""" config_path = Path("config.json") if not config_path.exists(): with open(config_path, "w", encoding="utf-8") as f: json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2) - print(f"已创建默认配置文件: {config_path}") + print(f"Default configuration file created: {config_path}") def parse_args() -> argparse.Namespace: - """解析命令行参数""" + """Parse command line arguments""" parser = argparse.ArgumentParser( - description="LightRAG Ollama 兼容接口测试", + description="LightRAG Ollama Compatibility Interface Testing", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" -配置文件 (config.json): +Configuration file (config.json): { "server": { - "host": "localhost", # 服务器地址 - "port": 9621, # 服务器端口 - "model": "lightrag:latest" # 默认模型名称 + "host": "localhost", # Server address + "port": 9621, # Server port + "model": "lightrag:latest" # Default model name }, "test_cases": { "basic": { - "query": "测试查询", # 基本查询文本 - "stream_query": "流式查询" # 流式查询文本 + "query": "Test query", # Basic query text + "stream_query": "Stream query" # Stream query text } } } @@ -522,44 +501,44 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "-q", "--quiet", action="store_true", - help="静默模式,只显示测试结果摘要" + help="Silent mode, only display test result summary" ) parser.add_argument( "-a", "--ask", type=str, - help="指定查询内容,会覆盖配置文件中的查询设置" + help="Specify query content, which will override the query settings in the configuration file" ) parser.add_argument( "--init-config", action="store_true", - 
help="创建默认配置文件" + help="Create default configuration file" ) parser.add_argument( "--output", type=str, default="", - help="测试结果输出文件路径,默认不输出到文件" + help="Test result output file path, default is not to output to a file" ) parser.add_argument( "--tests", nargs="+", choices=list(get_test_cases().keys()) + ["all"], default=["all"], - help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试" + help="Test cases to run, options: %(choices)s. Use 'all' to run all tests" ) return parser.parse_args() if __name__ == "__main__": args = parse_args() - # 设置输出模式 + # Set output mode OutputControl.set_verbose(not args.quiet) - # 如果指定了查询内容,更新配置 + # If query content is specified, update the configuration if args.ask: CONFIG["test_cases"]["basic"]["query"] = args.ask - # 如果指定了创建配置文件 + # If specified to create a configuration file if args.init_config: create_default_config() exit(0) @@ -568,31 +547,31 @@ if __name__ == "__main__": try: if "all" in args.tests: - # 运行所有测试 + # Run all tests if OutputControl.is_verbose(): - print("\n【基本功能测试】") - run_test(test_non_stream_chat, "非流式调用测试") - run_test(test_stream_chat, "流式调用测试") + print("\n【Basic Functionality Tests】") + run_test(test_non_stream_chat, "Non-streaming Call Test") + run_test(test_stream_chat, "Streaming Call Test") if OutputControl.is_verbose(): - print("\n【查询模式测试】") - run_test(test_query_modes, "查询模式测试") + print("\n【Query Mode Tests】") + run_test(test_query_modes, "Query Mode Test") if OutputControl.is_verbose(): - print("\n【错误处理测试】") - run_test(test_error_handling, "错误处理测试") - run_test(test_stream_error_handling, "流式错误处理测试") + print("\n【Error Handling Tests】") + run_test(test_error_handling, "Error Handling Test") + run_test(test_stream_error_handling, "Streaming Error Handling Test") else: - # 运行指定的测试 + # Run specified tests for test_name in args.tests: if OutputControl.is_verbose(): - print(f"\n【运行测试: {test_name}】") + print(f"\n【Running Test: {test_name}】") run_test(test_cases[test_name], test_name) except Exception as e: - print(f"\n发生错误: {str(e)}") + print(f"\nAn error occurred: {str(e)}") finally: - # 打印测试统计 + # Print test statistics STATS.print_summary() - # 如果指定了输出文件路径,则导出结果 + # If an output file path is specified, export the results if args.output: STATS.export_results(args.output) From fa9765ecd94560dcd3cf447b235ec2ec70ff0701 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 14:20:55 +0800 Subject: [PATCH 40/42] pre-commit run --all-files --- lightrag/api/lightrag_ollama.py | 165 ++++++++++++++++++-------------- test_lightrag_ollama_chat.py | 147 ++++++++++++++-------------- 2 files changed, 163 insertions(+), 149 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index bb9b1ac5..af991c19 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -24,22 +24,25 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN from dotenv import load_dotenv + load_dotenv() + def estimate_tokens(text: str) -> int: """Estimate the number of tokens in text Chinese characters: approximately 1.5 tokens per character English characters: approximately 0.25 tokens per character """ # Use regex to match Chinese and non-Chinese characters separately - chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) - non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text)) - + chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) + non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text)) + # Calculate estimated token count tokens = chinese_chars 
* 1.5 + non_chinese_chars * 0.25 - + return int(tokens) + # Constants for model information LIGHTRAG_NAME = "lightrag" LIGHTRAG_TAG = "latest" @@ -48,6 +51,7 @@ LIGHTRAG_SIZE = 7365960935 LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" LIGHTRAG_DIGEST = "sha256:lightrag" + async def llm_model_func( prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs ) -> str: @@ -61,6 +65,7 @@ async def llm_model_func( **kwargs, ) + def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": "http://m4.lan.znipower.com:11434", @@ -245,27 +250,32 @@ class SearchMode(str, Enum): hybrid = "hybrid" mix = "mix" + # Ollama API compatible models class OllamaMessage(BaseModel): role: str content: str images: Optional[List[str]] = None + class OllamaChatRequest(BaseModel): model: str = LIGHTRAG_MODEL messages: List[OllamaMessage] stream: bool = True # Default to streaming mode options: Optional[Dict[str, Any]] = None + class OllamaChatResponse(BaseModel): model: str created_at: str message: OllamaMessage done: bool + class OllamaVersionResponse(BaseModel): version: str + class OllamaModelDetails(BaseModel): parent_model: str format: str @@ -274,6 +284,7 @@ class OllamaModelDetails(BaseModel): parameter_size: str quantization_level: str + class OllamaModel(BaseModel): name: str model: str @@ -282,9 +293,11 @@ class OllamaModel(BaseModel): modified_at: str details: OllamaModelDetails + class OllamaTagResponse(BaseModel): models: List[OllamaModel] + # Original LightRAG models class QueryRequest(BaseModel): query: str @@ -292,9 +305,11 @@ class QueryRequest(BaseModel): stream: bool = False only_need_context: bool = False + class QueryResponse(BaseModel): response: str + class InsertTextRequest(BaseModel): text: str description: Optional[str] = None @@ -395,7 +410,9 @@ def create_app(args): embedding_dim=1024, max_token_size=8192, func=lambda texts: ollama_embedding( - texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434" + texts, + embed_model="bge-m3:latest", + host="http://m4.lan.znipower.com:11434", ), ), ) @@ -493,7 +510,7 @@ def create_app(args): # If response is a string (e.g. 
cache hit), return directly if isinstance(response, str): return QueryResponse(response=response) - + # If it's an async generator, decide whether to stream based on stream parameter if request.stream: result = "" @@ -546,8 +563,8 @@ def create_app(args): "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Headers": "Content-Type", - "X-Accel-Buffering": "no" # 禁用 Nginx 缓冲 - } + "X-Accel-Buffering": "no", # 禁用 Nginx 缓冲 + }, ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -652,29 +669,29 @@ def create_app(args): @app.get("/api/version") async def get_version(): """Get Ollama version information""" - return OllamaVersionResponse( - version="0.5.4" - ) + return OllamaVersionResponse(version="0.5.4") @app.get("/api/tags") async def get_tags(): """Get available models""" return OllamaTagResponse( - models=[{ - "name": LIGHTRAG_MODEL, - "model": LIGHTRAG_MODEL, - "size": LIGHTRAG_SIZE, - "digest": LIGHTRAG_DIGEST, - "modified_at": LIGHTRAG_CREATED_AT, - "details": { - "parent_model": "", - "format": "gguf", - "family": LIGHTRAG_NAME, - "families": [LIGHTRAG_NAME], - "parameter_size": "13B", - "quantization_level": "Q4_0" - } - }] + models=[ + { + "name": LIGHTRAG_MODEL, + "model": LIGHTRAG_MODEL, + "size": LIGHTRAG_SIZE, + "digest": LIGHTRAG_DIGEST, + "modified_at": LIGHTRAG_CREATED_AT, + "details": { + "parent_model": "", + "format": "gguf", + "family": LIGHTRAG_NAME, + "families": [LIGHTRAG_NAME], + "parameter_size": "13B", + "quantization_level": "Q4_0", + }, + } + ] ) def parse_query_mode(query: str) -> tuple[str, SearchMode]: @@ -686,15 +703,15 @@ def create_app(args): "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword "/naive ": SearchMode.naive, "/hybrid ": SearchMode.hybrid, - "/mix ": SearchMode.mix + "/mix ": SearchMode.mix, } - + for prefix, mode in mode_map.items(): if query.startswith(prefix): # After removing prefix an leading spaces - cleaned_query = query[len(prefix):].lstrip() + cleaned_query = query[len(prefix) :].lstrip() return cleaned_query, mode - + return query, SearchMode.hybrid @app.post("/api/chat") @@ -705,32 +722,29 @@ def create_app(args): messages = request.messages if not messages: raise HTTPException(status_code=400, detail="No messages provided") - + # Get the last message as query query = messages[-1].content - + # 解析查询模式 cleaned_query, mode = parse_query_mode(query) - + # 开始计时 start_time = time.time_ns() - + # 计算输入token数量 prompt_tokens = estimate_tokens(cleaned_query) - + # 调用RAG进行查询 query_param = QueryParam( - mode=mode, - stream=request.stream, - only_need_context=False + mode=mode, stream=request.stream, only_need_context=False ) - + if request.stream: from fastapi.responses import StreamingResponse - + response = await rag.aquery( # Need await to get async generator - cleaned_query, - param=query_param + cleaned_query, param=query_param ) async def stream_generator(): @@ -738,33 +752,37 @@ def create_app(args): first_chunk_time = None last_chunk_time = None total_response = "" - + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts first_chunk_time = time.time_ns() last_chunk_time = first_chunk_time total_response = response - + # 第一次发送查询内容 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "message": { - "role": "assistant", + "role": "assistant", "content": response, - "images": None + "images": None, }, - "done": False + "done": False, } yield 
f"{json.dumps(data, ensure_ascii=False)}\n" - + # 计算各项指标 completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time # 总时间 - prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 - eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 - + prompt_eval_time = ( + first_chunk_time - start_time + ) # 首个响应之前的时间 + eval_time = ( + last_chunk_time - first_chunk_time + ) # 生成响应的时间 + # 第二次发送统计信息 data = { "model": LIGHTRAG_MODEL, @@ -775,7 +793,7 @@ def create_app(args): "prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time # 生成响应的时间 + "eval_duration": eval_time, # 生成响应的时间 } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: @@ -785,10 +803,10 @@ def create_app(args): # 记录第一个chunk的时间 if first_chunk_time is None: first_chunk_time = time.time_ns() - + # 更新最后一个chunk的时间 last_chunk_time = time.time_ns() - + # 累积响应内容 total_response += chunk data = { @@ -797,18 +815,22 @@ def create_app(args): "message": { "role": "assistant", "content": chunk, - "images": None + "images": None, }, - "done": False + "done": False, } yield f"{json.dumps(data, ensure_ascii=False)}\n" - + # 计算各项指标 completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time # 总时间 - prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 - eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 - + prompt_eval_time = ( + first_chunk_time - start_time + ) # 首个响应之前的时间 + eval_time = ( + last_chunk_time - first_chunk_time + ) # 生成响应的时间 + # 发送完成标记,包含性能统计信息 data = { "model": LIGHTRAG_MODEL, @@ -819,14 +841,14 @@ def create_app(args): "prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time # 生成响应的时间 + "eval_duration": eval_time, # 生成响应的时间 } yield f"{json.dumps(data, ensure_ascii=False)}\n" return # 确保生成器在发送完成标记后立即结束 except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") raise - + return StreamingResponse( stream_generator(), media_type="application/x-ndjson", @@ -836,28 +858,25 @@ def create_app(args): "Content-Type": "application/x-ndjson", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type" - } + "Access-Control-Allow-Headers": "Content-Type", + }, ) else: # 非流式响应 first_chunk_time = time.time_ns() - response_text = await rag.aquery( - cleaned_query, - param=query_param - ) + response_text = await rag.aquery(cleaned_query, param=query_param) last_chunk_time = time.time_ns() - + # 确保响应不为空 if not response_text: response_text = "No response generated" - + # 计算各项指标 completion_tokens = estimate_tokens(str(response_text)) total_time = last_chunk_time - start_time # 总时间 prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 - + # 构造响应,包含性能统计信息 return { "model": LIGHTRAG_MODEL, @@ -865,7 +884,7 @@ def create_app(args): "message": { "role": "assistant", "content": str(response_text), # 确保转换为字符串 - "images": None + "images": None, }, "done": True, "total_duration": total_time, # 总时间 @@ -873,7 +892,7 @@ def create_app(args): "prompt_eval_count": prompt_tokens, # 输入token数 "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time # 生成响应的时间 + "eval_duration": eval_time, # 生成响应的时间 } except Exception as e: raise 
HTTPException(status_code=500, detail=str(e)) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 4f6cab29..96aee692 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -18,8 +18,10 @@ from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path + class OutputControl: """Output control class, manages the verbosity of test output""" + _verbose: bool = False @classmethod @@ -30,9 +32,11 @@ class OutputControl: def is_verbose(cls) -> bool: return cls._verbose + @dataclass class TestResult: """Test result data class""" + name: str success: bool duration: float @@ -43,8 +47,10 @@ class TestResult: if not self.timestamp: self.timestamp = datetime.now().isoformat() + class TestStats: """Test statistics""" + def __init__(self): self.results: List[TestResult] = [] self.start_time = datetime.now() @@ -65,8 +71,8 @@ class TestStats: "total": len(self.results), "passed": sum(1 for r in self.results if r.success), "failed": sum(1 for r in self.results if not r.success), - "total_duration": sum(r.duration for r in self.results) - } + "total_duration": sum(r.duration for r in self.results), + }, } with open(path, "w", encoding="utf-8") as f: @@ -92,6 +98,7 @@ class TestStats: if not result.success: print(f"- {result.name}: {result.error}") + DEFAULT_CONFIG = { "server": { "host": "localhost", @@ -99,16 +106,15 @@ DEFAULT_CONFIG = { "model": "lightrag:latest", "timeout": 30, "max_retries": 3, - "retry_delay": 1 + "retry_delay": 1, }, - "test_cases": { - "basic": { - "query": "唐僧有几个徒弟" - } - } + "test_cases": {"basic": {"query": "唐僧有几个徒弟"}}, } -def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response: + +def make_request( + url: str, data: Dict[str, Any], stream: bool = False +) -> requests.Response: """Send an HTTP request with retry mechanism Args: url: Request URL @@ -127,12 +133,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques for attempt in range(max_retries): try: - response = requests.post( - url, - json=data, - stream=stream, - timeout=timeout - ) + response = requests.post(url, json=data, stream=stream, timeout=timeout) return response except requests.exceptions.RequestException as e: if attempt == max_retries - 1: # Last retry @@ -140,6 +141,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}") time.sleep(retry_delay) + def load_config() -> Dict[str, Any]: """Load configuration file @@ -154,6 +156,7 @@ def load_config() -> Dict[str, Any]: return json.load(f) return DEFAULT_CONFIG + def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: """Format and print JSON response data Args: @@ -166,18 +169,19 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) print(f"\n=== {title} ===") print(json.dumps(data, ensure_ascii=False, indent=indent)) + # Global configuration CONFIG = load_config() + def get_base_url() -> str: """Return the base URL""" server = CONFIG["server"] return f"http://{server['host']}:{server['port']}/api/chat" + def create_request_data( - content: str, - stream: bool = False, - model: str = None + content: str, stream: bool = False, model: str = None ) -> Dict[str, Any]: """Create basic request data Args: @@ -189,18 +193,15 @@ def create_request_data( """ return { "model": model or CONFIG["server"]["model"], - "messages": [ - { - "role": "user", - "content": 
content - } - ], - "stream": stream + "messages": [{"role": "user", "content": content}], + "stream": stream, } + # Global test statistics STATS = TestStats() + def run_test(func: Callable, name: str) -> None: """Run a test and record the results Args: @@ -217,13 +218,11 @@ def run_test(func: Callable, name: str) -> None: STATS.add_result(TestResult(name, False, duration, str(e))) raise + def test_non_stream_chat(): """Test non-streaming call to /api/chat endpoint""" url = get_base_url() - data = create_request_data( - CONFIG["test_cases"]["basic"]["query"], - stream=False - ) + data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False) # Send request response = make_request(url, data) @@ -234,10 +233,12 @@ def test_non_stream_chat(): response_json = response.json() # Print response content - print_json_response({ - "model": response_json["model"], - "message": response_json["message"] - }, "Response content") + print_json_response( + {"model": response_json["model"], "message": response_json["message"]}, + "Response content", + ) + + def test_stream_chat(): """Test streaming call to /api/chat endpoint @@ -257,10 +258,7 @@ def test_stream_chat(): The last message will contain performance statistics, with done set to true. """ url = get_base_url() - data = create_request_data( - CONFIG["test_cases"]["basic"]["query"], - stream=True - ) + data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True) # Send request and get streaming response response = make_request(url, data, stream=True) @@ -273,9 +271,11 @@ def test_stream_chat(): if line: # Skip empty lines try: # Decode and parse JSON - data = json.loads(line.decode('utf-8')) + data = json.loads(line.decode("utf-8")) if data.get("done", True): # If it's the completion marker - if "total_duration" in data: # Final performance statistics message + if ( + "total_duration" in data + ): # Final performance statistics message # print_json_response(data, "Performance statistics") break else: # Normal content message @@ -283,7 +283,9 @@ def test_stream_chat(): content = message.get("content", "") if content: # Only collect non-empty content output_buffer.append(content) - print(content, end="", flush=True) # Print content in real-time + print( + content, end="", flush=True + ) # Print content in real-time except json.JSONDecodeError: print("Error decoding JSON from response line") finally: @@ -292,6 +294,7 @@ def test_stream_chat(): # Print a newline print() + def test_query_modes(): """Test different query mode prefixes @@ -311,8 +314,7 @@ def test_query_modes(): if OutputControl.is_verbose(): print(f"\n=== Testing /{mode} mode ===") data = create_request_data( - f"/{mode} {CONFIG['test_cases']['basic']['query']}", - stream=False + f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False ) # Send request @@ -320,10 +322,10 @@ def test_query_modes(): response_json = response.json() # Print response content - print_json_response({ - "model": response_json["model"], - "message": response_json["message"] - }) + print_json_response( + {"model": response_json["model"], "message": response_json["message"]} + ) + def create_error_test_data(error_type: str) -> Dict[str, Any]: """Create request data for error testing @@ -337,33 +339,21 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: Request dictionary containing error data """ error_data = { - "empty_messages": { - "model": "lightrag:latest", - "messages": [], - "stream": True - }, + "empty_messages": {"model": "lightrag:latest", 
"messages": [], "stream": True}, "invalid_role": { "model": "lightrag:latest", - "messages": [ - { - "invalid_role": "user", - "content": "Test message" - } - ], - "stream": True + "messages": [{"invalid_role": "user", "content": "Test message"}], + "stream": True, }, "missing_content": { "model": "lightrag:latest", - "messages": [ - { - "role": "user" - } - ], - "stream": True - } + "messages": [{"role": "user"}], + "stream": True, + }, } return error_data.get(error_type, error_data["empty_messages"]) + def test_stream_error_handling(): """Test error handling for streaming responses @@ -409,6 +399,7 @@ def test_stream_error_handling(): print_json_response(response.json(), "Error message") response.close() + def test_error_handling(): """Test error handling for non-streaming responses @@ -455,6 +446,7 @@ def test_error_handling(): print(f"Status code: {response.status_code}") print_json_response(response.json(), "Error message") + def get_test_cases() -> Dict[str, Callable]: """Get all available test cases Returns: @@ -465,9 +457,10 @@ def get_test_cases() -> Dict[str, Callable]: "stream": test_stream_chat, "modes": test_query_modes, "errors": test_error_handling, - "stream_errors": test_stream_error_handling + "stream_errors": test_stream_error_handling, } + def create_default_config(): """Create a default configuration file""" config_path = Path("config.json") @@ -476,6 +469,7 @@ def create_default_config(): json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2) print(f"Default configuration file created: {config_path}") + def parse_args() -> argparse.Namespace: """Parse command line arguments""" parser = argparse.ArgumentParser( @@ -496,38 +490,39 @@ Configuration file (config.json): } } } -""" +""", ) parser.add_argument( - "-q", "--quiet", + "-q", + "--quiet", action="store_true", - help="Silent mode, only display test result summary" + help="Silent mode, only display test result summary", ) parser.add_argument( - "-a", "--ask", + "-a", + "--ask", type=str, - help="Specify query content, which will override the query settings in the configuration file" + help="Specify query content, which will override the query settings in the configuration file", ) parser.add_argument( - "--init-config", - action="store_true", - help="Create default configuration file" + "--init-config", action="store_true", help="Create default configuration file" ) parser.add_argument( "--output", type=str, default="", - help="Test result output file path, default is not to output to a file" + help="Test result output file path, default is not to output to a file", ) parser.add_argument( "--tests", nargs="+", choices=list(get_test_cases().keys()) + ["all"], default=["all"], - help="Test cases to run, options: %(choices)s. Use 'all' to run all tests" + help="Test cases to run, options: %(choices)s. 
Use 'all' to run all tests", ) return parser.parse_args() + if __name__ == "__main__": args = parse_args() From a5618790403e93d5bebfa3cb0207d30a4dbbdae9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 14:27:27 +0800 Subject: [PATCH 41/42] Translate comments to English --- lightrag/api/lightrag_ollama.py | 81 ++++++++++++++------------------- 1 file changed, 34 insertions(+), 47 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index af991c19..3856488e 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -541,13 +541,13 @@ def create_app(args): async def stream_generator(): if isinstance(response, str): - # 如果是字符串,一次性发送 + # If it's a string, send it all at once yield f"{json.dumps({'response': response})}\n" else: - # 如果是异步生成器,逐块发送 + # If it's an async generator, send chunks one by one try: async for chunk in response: - if chunk: # 只发送非空内容 + if chunk: # Only send non-empty content yield f"{json.dumps({'response': chunk})}\n" except Exception as e: logging.error(f"Streaming error: {str(e)}") @@ -563,7 +563,7 @@ def create_app(args): "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Headers": "Content-Type", - "X-Accel-Buffering": "no", # 禁用 Nginx 缓冲 + "X-Accel-Buffering": "no", # Disable Nginx buffering }, ) except Exception as e: @@ -760,7 +760,6 @@ def create_app(args): last_chunk_time = first_chunk_time total_response = response - # 第一次发送查询内容 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, @@ -773,41 +772,35 @@ def create_app(args): } yield f"{json.dumps(data, ensure_ascii=False)}\n" - # 计算各项指标 completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time # 总时间 + total_time = last_chunk_time - start_time prompt_eval_time = ( first_chunk_time - start_time - ) # 首个响应之前的时间 + ) eval_time = ( last_chunk_time - first_chunk_time - ) # 生成响应的时间 + ) - # 第二次发送统计信息 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "done": True, - "total_duration": total_time, # 总时间 - "load_duration": 0, # 加载时间为0 - "prompt_eval_count": prompt_tokens, # 输入token数 - "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 - "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time, # 生成响应的时间 + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: - # 流式响应 async for chunk in response: - if chunk: # 只发送非空内容 - # 记录第一个chunk的时间 + if chunk: if first_chunk_time is None: first_chunk_time = time.time_ns() - # 更新最后一个chunk的时间 last_chunk_time = time.time_ns() - # 累积响应内容 total_response += chunk data = { "model": LIGHTRAG_MODEL, @@ -821,30 +814,28 @@ def create_app(args): } yield f"{json.dumps(data, ensure_ascii=False)}\n" - # 计算各项指标 completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time # 总时间 + total_time = last_chunk_time - start_time prompt_eval_time = ( first_chunk_time - start_time - ) # 首个响应之前的时间 + ) eval_time = ( last_chunk_time - first_chunk_time - ) # 生成响应的时间 + ) - # 发送完成标记,包含性能统计信息 data = { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "done": True, - "total_duration": total_time, # 总时间 - "load_duration": 0, # 加载时间为0 - "prompt_eval_count": prompt_tokens, # 输入token数 - "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 - "eval_count": 
completion_tokens, # 输出token数 - "eval_duration": eval_time, # 生成响应的时间 + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, } yield f"{json.dumps(data, ensure_ascii=False)}\n" - return # 确保生成器在发送完成标记后立即结束 + return # Ensure the generator ends immediately after sending the completion marker except Exception as e: logging.error(f"Error in stream_generator: {str(e)}") raise @@ -862,37 +853,33 @@ def create_app(args): }, ) else: - # 非流式响应 first_chunk_time = time.time_ns() response_text = await rag.aquery(cleaned_query, param=query_param) last_chunk_time = time.time_ns() - # 确保响应不为空 if not response_text: response_text = "No response generated" - # 计算各项指标 completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time # 总时间 - prompt_eval_time = first_chunk_time - start_time # 首个响应之前的时间 - eval_time = last_chunk_time - first_chunk_time # 生成响应的时间 + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time - # 构造响应,包含性能统计信息 return { "model": LIGHTRAG_MODEL, "created_at": LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": str(response_text), # 确保转换为字符串 + "content": str(response_text), "images": None, }, "done": True, - "total_duration": total_time, # 总时间 - "load_duration": 0, # 加载时间为0 - "prompt_eval_count": prompt_tokens, # 输入token数 - "prompt_eval_duration": prompt_eval_time, # 首个响应之前的时间 - "eval_count": completion_tokens, # 输出token数 - "eval_duration": eval_time, # 生成响应的时间 + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) From fde0aa32c7d0cda13b734879f823432f90e14c41 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 17 Jan 2025 14:28:24 +0800 Subject: [PATCH 42/42] pre-commit run --all-files --- lightrag/api/lightrag_ollama.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py index 3856488e..fc7ae29c 100644 --- a/lightrag/api/lightrag_ollama.py +++ b/lightrag/api/lightrag_ollama.py @@ -774,12 +774,8 @@ def create_app(args): completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time - prompt_eval_time = ( - first_chunk_time - start_time - ) - eval_time = ( - last_chunk_time - first_chunk_time - ) + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time data = { "model": LIGHTRAG_MODEL, @@ -816,12 +812,8 @@ def create_app(args): completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time - prompt_eval_time = ( - first_chunk_time - start_time - ) - eval_time = ( - last_chunk_time - first_chunk_time - ) + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time data = { "model": LIGHTRAG_MODEL,