From 57b015bee18314abd8b12436e68899de8f499a5d Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Wed, 5 Feb 2025 03:22:22 +0800 Subject: [PATCH 1/9] fix doc_key filtering logic to handle dict status --- lightrag/lightrag.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 3014f737..420b82eb 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -372,12 +372,23 @@ class LightRAG: # 3. Filter out already processed documents # _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) - _add_doc_keys = { - doc_id - for doc_id in new_docs.keys() - if (current_doc := await self.doc_status.get_by_id(doc_id)) is None - or current_doc.status == DocStatus.FAILED - } + _add_doc_keys = set() + for doc_id in new_docs.keys(): + current_doc = await self.doc_status.get_by_id(doc_id) + + if current_doc is None: + _add_doc_keys.add(doc_id) + continue # skip to the next doc_id + + status = None + if isinstance(current_doc, dict): + status = current_doc["status"] + else: + status = current_doc.status + + if status == DocStatus.FAILED: + _add_doc_keys.add(doc_id) + new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} if not new_docs: From e2d164e8c8bba027b3f46ab74a8384984a2d6c54 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Wed, 5 Feb 2025 03:52:59 +0800 Subject: [PATCH 2/9] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index ee5519f5..06e914ba 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News +- [x] [2025.02.05]🎯📢Our team has released [VideoRAG](https://github.com/HKUDS/VideoRAG) for processing and understanding extremely long-context videos. - [x] [2025.01.13]🎯📢Our team has released [MiniRAG](https://github.com/HKUDS/MiniRAG) making RAG simpler with small models. - [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage). - [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). 
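The core of PATCH 1/9 above is tolerating two shapes of doc-status records returned by different storage backends. A minimal, self-contained sketch of that pattern follows; the DocStatus members and the extract_status helper name are illustrative assumptions, not code from the patch:

    from dataclasses import dataclass
    from enum import Enum
    from typing import Any

    class DocStatus(str, Enum):
        PENDING = "pending"
        FAILED = "failed"

    @dataclass
    class DocProcessingStatus:
        status: DocStatus

    def extract_status(current_doc: Any) -> DocStatus:
        # Storage backends may hand back either a plain dict or a status object,
        # so read the "status" field in whichever shape arrives.
        if isinstance(current_doc, dict):
            return current_doc["status"]
        return current_doc.status

    # Both shapes resolve to the same value, so the FAILED re-queue check works for either:
    assert extract_status({"status": DocStatus.FAILED}) is DocStatus.FAILED
    assert extract_status(DocProcessingStatus(status=DocStatus.FAILED)) is DocStatus.FAILED
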
From 69f200faf2f755b709c07b4557714fe8f3ace930 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Feb 2025 09:46:56 +0800 Subject: [PATCH 3/9] feat: improve error handling for streaming responses MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add CancelledError handling for streams • Send error details to client in JSON • Add error status codes and messages • Always send final completion marker • Refactor stream generator error handling --- lightrag/api/lightrag_server.py | 115 +++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 40 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index ec58f552..27042a27 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -12,7 +12,7 @@ from fastapi import ( # Add this to store progress globally from typing import Dict import threading - +import asyncio import json import os @@ -1718,11 +1718,11 @@ def create_app(args): ) async def stream_generator(): + first_chunk_time = None + last_chunk_time = None + total_response = "" + try: - first_chunk_time = None - last_chunk_time = None - total_response = "" - # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts @@ -1760,46 +1760,81 @@ def create_app(args): } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() - last_chunk_time = time.time_ns() + last_chunk_time = time.time_ns() - total_response += chunk - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + total_response += chunk + data = { + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except asyncio.CancelledError: + error_data = { + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, + "error": { + "code": "STREAM_CANCELLED", + "message": "Stream was cancelled by server" + }, + "done": False + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + raise - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + if last_chunk_time is not None: + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": 
ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - return # Ensure the generator ends immediately after sending the completion marker except Exception as e: - logging.error(f"Error in stream_generator: {str(e)}") + error_msg = f"Error in stream_generator: {str(e)}" + logging.error(error_msg) + + # 发送错误消息给客户端 + error_data = { + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, + "error": { + "code": "STREAM_ERROR", + "message": error_msg + }, + "done": False + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # 确保发送结束标记 + final_data = { + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" raise return StreamingResponse( From ff40e61fad530ef516934d7881f9ad1f6f999311 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Feb 2025 09:47:39 +0800 Subject: [PATCH 4/9] Fix linting --- lightrag/api/lightrag_server.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 27042a27..131be01f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1721,7 +1721,7 @@ def create_app(args): first_chunk_time = None last_chunk_time = None total_response = "" - + try: # Ensure response is an async generator if isinstance(response, str): @@ -1786,9 +1786,9 @@ def create_app(args): "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "error": { "code": "STREAM_CANCELLED", - "message": "Stream was cancelled by server" + "message": "Stream was cancelled by server", }, - "done": False + "done": False, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" raise @@ -1815,24 +1815,21 @@ def create_app(args): except Exception as e: error_msg = f"Error in stream_generator: {str(e)}" logging.error(error_msg) - + # 发送错误消息给客户端 error_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": { - "code": "STREAM_ERROR", - "message": error_msg - }, - "done": False + "error": {"code": "STREAM_ERROR", "message": error_msg}, + "done": False, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - + # 确保发送结束标记 final_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True + "done": True, } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" raise From 24effb127dbff8ddfc9868984460bd99ebbf65cd Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Feb 2025 10:44:48 +0800 Subject: [PATCH 5/9] Improve error handling and response consistency in streaming endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add error message forwarding to client • Handle stream cancellations gracefully • Add logging for stream errors • Ensure clean stream termination • Add try-catch in OpenAI streaming --- lightrag/api/lightrag_server.py | 38 ++++++++++++++++++++++++--------- lightrag/llm/openai.py | 18 ++++++++++------ 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 131be01f..d8412a13 100644 --- 
a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1780,18 +1780,34 @@ def create_app(args): "done": False, } yield f"{json.dumps(data, ensure_ascii=False)}\n" - except asyncio.CancelledError: + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logging.error(f"Stream error: {error_msg}") + + # Send error message to client error_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "error": { - "code": "STREAM_CANCELLED", - "message": "Stream was cancelled by server", + "code": "STREAM_ERROR", + "message": error_msg }, - "done": False, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - raise + + # Send final message to close the stream + final_data = { + "model": ollama_server_infos.LIGHTRAG_MODEL, + "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return if last_chunk_time is not None: completion_tokens = estimate_tokens(total_response) @@ -1816,23 +1832,25 @@ def create_app(args): error_msg = f"Error in stream_generator: {str(e)}" logging.error(error_msg) - # 发送错误消息给客户端 + # Send error message to client error_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": {"code": "STREAM_ERROR", "message": error_msg}, - "done": False, + "error": { + "code": "STREAM_ERROR", + "message": error_msg + }, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - # 确保发送结束标记 + # Ensure sending end marker final_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, "done": True, } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - raise + return return StreamingResponse( stream_generator(), diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 11ba69c0..4eaca093 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -125,13 +125,17 @@ async def openai_complete_if_cache( if hasattr(response, "__aiter__"): async def inner(): - async for chunk in response: - content = chunk.choices[0].delta.content - if content is None: - continue - if r"\u" in content: - content = safe_unicode_decode(content.encode("utf-8")) - yield content + try: + async for chunk in response: + content = chunk.choices[0].delta.content + if content is None: + continue + if r"\u" in content: + content = safe_unicode_decode(content.encode("utf-8")) + yield content + except Exception as e: + logger.error(f"Error in stream response: {str(e)}") + raise return inner() else: From f1ea7f7415ca53adf2576dab03d50ed137d5d61d Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Feb 2025 11:07:31 +0800 Subject: [PATCH 6/9] update error response format in streaming API to a normal message. So user can get what's going on. 
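A minimal sketch of the payload this change switches to: the error travels as an ordinary assistant message rather than an "error" object, so Ollama-compatible clients render it as chat text. The model name and timestamp below are stand-ins for the ollama_server_infos constants:

    import json

    MODEL = "lightrag:latest"            # stand-in for ollama_server_infos.LIGHTRAG_MODEL
    CREATED_AT = "2024-01-15T00:00:00Z"  # stand-in for ollama_server_infos.LIGHTRAG_CREATED_AT

    def error_as_chat_line(error_msg: str) -> str:
        # One NDJSON line: the error is carried in message.content, not in an "error" field.
        payload = {
            "model": MODEL,
            "created_at": CREATED_AT,
            "message": {
                "role": "assistant",
                "content": f"\n\nError: {error_msg}",
                "images": None,
            },
            "done": False,
        }
        return json.dumps(payload, ensure_ascii=False) + "\n"

    print(error_as_chat_line("Provider error: connection reset"), end="")
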
--- lightrag/api/lightrag_server.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index d8412a13..c9144d0e 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1793,10 +1793,12 @@ def create_app(args): error_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": { - "code": "STREAM_ERROR", - "message": error_msg + "message": { + "role": "assistant", + "content": f"\n\nError: {error_msg}", + "images": None }, + "done": False } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" From f77faf80237bd474e4c5f3fd33addfe2698e95fd Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Feb 2025 12:36:52 +0800 Subject: [PATCH 7/9] Fix linting --- lightrag/api/lightrag_server.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index aadd1c09..345136aa 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1793,9 +1793,9 @@ def create_app(args): error_msg = "Stream was cancelled by server" else: error_msg = f"Provider error: {error_msg}" - + logging.error(f"Stream error: {error_msg}") - + # Send error message to client error_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, @@ -1803,12 +1803,12 @@ def create_app(args): "message": { "role": "assistant", "content": f"\n\nError: {error_msg}", - "images": None + "images": None, }, - "done": False + "done": False, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - + # Send final message to close the stream final_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, @@ -1845,10 +1845,7 @@ def create_app(args): error_data = { "model": ollama_server_infos.LIGHTRAG_MODEL, "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": { - "code": "STREAM_ERROR", - "message": error_msg - }, + "error": {"code": "STREAM_ERROR", "message": error_msg}, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" From f703334ce4603654b6bfc1b5bd0c4fddd4e0e2ad Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Feb 2025 22:15:14 +0800 Subject: [PATCH 8/9] Split the Ollama API implementation to a separated file --- lightrag/api/lightrag_server.py | 607 +++----------------------------- lightrag/api/ollama_api.py | 554 +++++++++++++++++++++++++++++ 2 files changed, 594 insertions(+), 567 deletions(-) create mode 100644 lightrag/api/ollama_api.py diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 345136aa..2ee2838e 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -4,28 +4,20 @@ from fastapi import ( File, UploadFile, Form, - Request, BackgroundTasks, ) -# Backend (Python) -# Add this to store progress globally -from typing import Dict import threading -import asyncio -import json import os - +import json +import re from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel import logging import argparse -import time -import re -from typing import List, Any, Optional, Union +from typing import List, Any, Optional, Union, Dict +from pydantic import BaseModel from lightrag import LightRAG, QueryParam from lightrag.api import __api_version__ - from lightrag.utils import EmbeddingFunc from enum import Enum from pathlib import Path @@ -34,20 +26,30 @@ import aiofiles from ascii_colors import trace_exception, ASCIIColors import sys import configparser - from 
fastapi import Depends, Security from fastapi.security import APIKeyHeader from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager - from starlette.status import HTTP_403_FORBIDDEN import pipmaster as pm - from dotenv import load_dotenv +from .ollama_api import ( + OllamaAPI, +) +from .ollama_api import ollama_server_infos # Load environment variables load_dotenv() +class RAGStorageConfig: + KV_STORAGE = "JsonKVStorage" + DOC_STATUS_STORAGE = "JsonDocStatusStorage" + GRAPH_STORAGE = "NetworkXStorage" + VECTOR_STORAGE = "NanoVectorDBStorage" + +# Initialize rag storage config +rag_storage_config = RAGStorageConfig() + # Global progress tracker scan_progress: Dict = { "is_scanning": False, @@ -76,24 +78,6 @@ def estimate_tokens(text: str) -> int: return int(tokens) -class OllamaServerInfos: - # Constants for emulated Ollama model information - LIGHTRAG_NAME = "lightrag" - LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") - LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" - LIGHTRAG_SIZE = 7365960935 # it's a dummy value - LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" - LIGHTRAG_DIGEST = "sha256:lightrag" - - KV_STORAGE = "JsonKVStorage" - DOC_STATUS_STORAGE = "JsonDocStatusStorage" - GRAPH_STORAGE = "NetworkXStorage" - VECTOR_STORAGE = "NanoVectorDBStorage" - - -# Add infos -ollama_server_infos = OllamaServerInfos() - # read config.ini config = configparser.ConfigParser() config.read("config.ini", "utf-8") @@ -101,8 +85,8 @@ config.read("config.ini", "utf-8") redis_uri = config.get("redis", "uri", fallback=None) if redis_uri: os.environ["REDIS_URI"] = redis_uri - ollama_server_infos.KV_STORAGE = "RedisKVStorage" - ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage" + rag_storage_config.KV_STORAGE = "RedisKVStorage" + rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage" # Neo4j config neo4j_uri = config.get("neo4j", "uri", fallback=None) @@ -112,7 +96,7 @@ if neo4j_uri: os.environ["NEO4J_URI"] = neo4j_uri os.environ["NEO4J_USERNAME"] = neo4j_username os.environ["NEO4J_PASSWORD"] = neo4j_password - ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage" + rag_storage_config.GRAPH_STORAGE = "Neo4JStorage" # Milvus config milvus_uri = config.get("milvus", "uri", fallback=None) @@ -124,7 +108,7 @@ if milvus_uri: os.environ["MILVUS_USER"] = milvus_user os.environ["MILVUS_PASSWORD"] = milvus_password os.environ["MILVUS_DB_NAME"] = milvus_db_name - ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge" + rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge" # MongoDB config mongo_uri = config.get("mongodb", "uri", fallback=None) @@ -132,8 +116,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None) if mongo_uri: os.environ["MONGO_URI"] = mongo_uri os.environ["MONGO_DATABASE"] = mongo_database - ollama_server_infos.KV_STORAGE = "MongoKVStorage" - ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage" + rag_storage_config.KV_STORAGE = "MongoKVStorage" + rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage" def get_default_host(binding_type: str) -> str: @@ -535,6 +519,7 @@ def parse_args() -> argparse.Namespace: help="Cosine similarity threshold (default: from env or 0.4)", ) + # Ollama model name parser.add_argument( "--simulated-model-name", type=str, @@ -599,84 +584,13 @@ class DocumentManager: return any(filename.lower().endswith(ext) for ext in self.supported_extensions) -# Pydantic models +# LightRAG query mode class SearchMode(str, Enum): naive = "naive" local = "local" global_ = "global" hybrid = "hybrid" 
mix = "mix" - bypass = "bypass" - - -class OllamaMessage(BaseModel): - role: str - content: str - images: Optional[List[str]] = None - - -class OllamaChatRequest(BaseModel): - model: str = ollama_server_infos.LIGHTRAG_MODEL - messages: List[OllamaMessage] - stream: bool = True # Default to streaming mode - options: Optional[Dict[str, Any]] = None - system: Optional[str] = None - - -class OllamaChatResponse(BaseModel): - model: str - created_at: str - message: OllamaMessage - done: bool - - -class OllamaGenerateRequest(BaseModel): - model: str = ollama_server_infos.LIGHTRAG_MODEL - prompt: str - system: Optional[str] = None - stream: bool = False - options: Optional[Dict[str, Any]] = None - - -class OllamaGenerateResponse(BaseModel): - model: str - created_at: str - response: str - done: bool - context: Optional[List[int]] - total_duration: Optional[int] - load_duration: Optional[int] - prompt_eval_count: Optional[int] - prompt_eval_duration: Optional[int] - eval_count: Optional[int] - eval_duration: Optional[int] - - -class OllamaVersionResponse(BaseModel): - version: str - - -class OllamaModelDetails(BaseModel): - parent_model: str - format: str - family: str - families: List[str] - parameter_size: str - quantization_level: str - - -class OllamaModel(BaseModel): - name: str - model: str - size: int - digest: str - modified_at: str - details: OllamaModelDetails - - -class OllamaTagResponse(BaseModel): - models: List[OllamaModel] - class QueryRequest(BaseModel): query: str @@ -920,10 +834,10 @@ def create_app(args): if args.llm_binding == "lollms" or args.llm_binding == "ollama" else {}, embedding_func=embedding_func, - kv_storage=ollama_server_infos.KV_STORAGE, - graph_storage=ollama_server_infos.GRAPH_STORAGE, - vector_storage=ollama_server_infos.VECTOR_STORAGE, - doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE, + kv_storage=rag_storage_config.KV_STORAGE, + graph_storage=rag_storage_config.GRAPH_STORAGE, + vector_storage=rag_storage_config.VECTOR_STORAGE, + doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -949,10 +863,10 @@ def create_app(args): llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, embedding_func=embedding_func, - kv_storage=ollama_server_infos.KV_STORAGE, - graph_storage=ollama_server_infos.GRAPH_STORAGE, - vector_storage=ollama_server_infos.VECTOR_STORAGE, - doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE, + kv_storage=rag_storage_config.KV_STORAGE, + graph_storage=rag_storage_config.GRAPH_STORAGE, + vector_storage=rag_storage_config.VECTOR_STORAGE, + doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -1475,450 +1389,9 @@ def create_app(args): async def get_graphs(label: str): return await rag.get_graps(nodel_label=label, max_depth=100) - # Ollama compatible API endpoints - # ------------------------------------------------- - @app.get("/api/version") - async def get_version(): - """Get Ollama version information""" - return OllamaVersionResponse(version="0.5.4") - - @app.get("/api/tags") - async def get_tags(): - """Retrun available models acting as an Ollama server""" - return OllamaTagResponse( - models=[ - { - "name": ollama_server_infos.LIGHTRAG_MODEL, - "model": ollama_server_infos.LIGHTRAG_MODEL, - "size": ollama_server_infos.LIGHTRAG_SIZE, - "digest": ollama_server_infos.LIGHTRAG_DIGEST, - "modified_at": 
ollama_server_infos.LIGHTRAG_CREATED_AT, - "details": { - "parent_model": "", - "format": "gguf", - "family": ollama_server_infos.LIGHTRAG_NAME, - "families": [ollama_server_infos.LIGHTRAG_NAME], - "parameter_size": "13B", - "quantization_level": "Q4_0", - }, - } - ] - ) - - def parse_query_mode(query: str) -> tuple[str, SearchMode]: - """Parse query prefix to determine search mode - Returns tuple of (cleaned_query, search_mode) - """ - mode_map = { - "/local ": SearchMode.local, - "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword - "/naive ": SearchMode.naive, - "/hybrid ": SearchMode.hybrid, - "/mix ": SearchMode.mix, - "/bypass ": SearchMode.bypass, - } - - for prefix, mode in mode_map.items(): - if query.startswith(prefix): - # After removing prefix an leading spaces - cleaned_query = query[len(prefix) :].lstrip() - return cleaned_query, mode - - return query, SearchMode.hybrid - - @app.post("/api/generate") - async def generate(raw_request: Request, request: OllamaGenerateRequest): - """Handle generate completion requests acting as an Ollama model - For compatiblity purpuse, the request is not processed by LightRAG, - and will be handled by underlying LLM model. - """ - try: - query = request.prompt - start_time = time.time_ns() - prompt_tokens = estimate_tokens(query) - - if request.system: - rag.llm_model_kwargs["system_prompt"] = request.system - - if request.stream: - from fastapi.responses import StreamingResponse - - response = await rag.llm_model_func( - query, stream=True, **rag.llm_model_kwargs - ) - - async def stream_generator(): - try: - first_chunk_time = None - last_chunk_time = None - total_response = "" - - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time - total_response = response - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": response, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": 
prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - return - - except Exception as e: - logging.error(f"Error in stream_generator: {str(e)}") - raise - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - }, - ) - else: - first_chunk_time = time.time_ns() - response_text = await rag.llm_model_func( - query, stream=False, **rag.llm_model_kwargs - ) - last_chunk_time = time.time_ns() - - if not response_text: - response_text = "No response generated" - - completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - return { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": str(response_text), - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/api/chat") - async def chat(raw_request: Request, request: OllamaChatRequest): - """Process chat completion requests acting as an Ollama model - Routes user queries through LightRAG by selecting query mode based on prefix indicators. - Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. 
- """ - try: - # Get all messages - messages = request.messages - if not messages: - raise HTTPException(status_code=400, detail="No messages provided") - - # Get the last message as query and previous messages as history - query = messages[-1].content - # Convert OllamaMessage objects to dictionaries - conversation_history = [ - {"role": msg.role, "content": msg.content} for msg in messages[:-1] - ] - - # Check for query prefix - cleaned_query, mode = parse_query_mode(query) - - start_time = time.time_ns() - prompt_tokens = estimate_tokens(cleaned_query) - - param_dict = { - "mode": mode, - "stream": request.stream, - "only_need_context": False, - "conversation_history": conversation_history, - "top_k": args.top_k, - } - - if args.history_turns is not None: - param_dict["history_turns"] = args.history_turns - - query_param = QueryParam(**param_dict) - - if request.stream: - from fastapi.responses import StreamingResponse - - # Determine if the request is prefix with "/bypass" - if mode == SearchMode.bypass: - if request.system: - rag.llm_model_kwargs["system_prompt"] = request.system - response = await rag.llm_model_func( - cleaned_query, - stream=True, - history_messages=conversation_history, - **rag.llm_model_kwargs, - ) - else: - response = await rag.aquery( # Need await to get async generator - cleaned_query, param=query_param - ) - - async def stream_generator(): - first_chunk_time = None - last_chunk_time = None - total_response = "" - - try: - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time - total_response = response - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": response, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - try: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - except (asyncio.CancelledError, Exception) as e: - error_msg = str(e) - if isinstance(e, asyncio.CancelledError): - error_msg = "Stream was cancelled by server" - else: - error_msg = f"Provider error: {error_msg}" - - logging.error(f"Stream error: {error_msg}") - - # Send error message to client - error_data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": f"\n\nError: {error_msg}", - "images": None, - }, - "done": False, - } 
- yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Send final message to close the stream - final_data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return - - if last_chunk_time is not None: - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - except Exception as e: - error_msg = f"Error in stream_generator: {str(e)}" - logging.error(error_msg) - - # Send error message to client - error_data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": {"code": "STREAM_ERROR", "message": error_msg}, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Ensure sending end marker - final_data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - }, - ) - else: - first_chunk_time = time.time_ns() - - # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task - match_result = re.search( - r"\n\nUSER:", cleaned_query, re.MULTILINE - ) - if match_result or mode == SearchMode.bypass: - if request.system: - rag.llm_model_kwargs["system_prompt"] = request.system - - response_text = await rag.llm_model_func( - cleaned_query, - stream=False, - history_messages=conversation_history, - **rag.llm_model_kwargs, - ) - else: - response_text = await rag.aquery(cleaned_query, param=query_param) - - last_chunk_time = time.time_ns() - - if not response_text: - response_text = "No response generated" - - completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - return { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": str(response_text), - "images": None, - }, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) + # Add Ollama API routes + ollama_api = OllamaAPI(rag) + app.include_router(ollama_api.router, prefix="/api") @app.get("/documents", dependencies=[Depends(optional_api_key)]) async def documents(): @@ 
-1945,10 +1418,10 @@ def create_app(args): "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, "max_tokens": args.max_tokens, - "kv_storage": ollama_server_infos.KV_STORAGE, - "doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE, - "graph_storage": ollama_server_infos.GRAPH_STORAGE, - "vector_storage": ollama_server_infos.VECTOR_STORAGE, + "kv_storage": rag_storage_config.KV_STORAGE, + "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE, + "graph_storage": rag_storage_config.GRAPH_STORAGE, + "vector_storage": rag_storage_config.VECTOR_STORAGE, }, } diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py new file mode 100644 index 00000000..49ec8414 --- /dev/null +++ b/lightrag/api/ollama_api.py @@ -0,0 +1,554 @@ +from fastapi import APIRouter, HTTPException, Request +from pydantic import BaseModel +from typing import List, Dict, Any, Optional +import logging +import time +import json +import re +import os +from enum import Enum +from fastapi.responses import StreamingResponse +import asyncio +from ascii_colors import trace_exception +from lightrag import LightRAG, QueryParam + +class OllamaServerInfos: + # Constants for emulated Ollama model information + LIGHTRAG_NAME = "lightrag" + LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") + LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" + LIGHTRAG_SIZE = 7365960935 # it's a dummy value + LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" + LIGHTRAG_DIGEST = "sha256:lightrag" + +ollama_server_infos = OllamaServerInfos() + +# query mode according to query prefix (bypass is not LightRAG quer mode) +class SearchMode(str, Enum): + naive = "naive" + local = "local" + global_ = "global" + hybrid = "hybrid" + mix = "mix" + bypass = "bypass" + +class OllamaMessage(BaseModel): + role: str + content: str + images: Optional[List[str]] = None + +class OllamaChatRequest(BaseModel): + model: str + messages: List[OllamaMessage] + stream: bool = True + options: Optional[Dict[str, Any]] = None + system: Optional[str] = None + +class OllamaChatResponse(BaseModel): + model: str + created_at: str + message: OllamaMessage + done: bool + +class OllamaGenerateRequest(BaseModel): + model: str + prompt: str + system: Optional[str] = None + stream: bool = False + options: Optional[Dict[str, Any]] = None + +class OllamaGenerateResponse(BaseModel): + model: str + created_at: str + response: str + done: bool + context: Optional[List[int]] + total_duration: Optional[int] + load_duration: Optional[int] + prompt_eval_count: Optional[int] + prompt_eval_duration: Optional[int] + eval_count: Optional[int] + eval_duration: Optional[int] + +class OllamaVersionResponse(BaseModel): + version: str + +class OllamaModelDetails(BaseModel): + parent_model: str + format: str + family: str + families: List[str] + parameter_size: str + quantization_level: str + +class OllamaModel(BaseModel): + name: str + model: str + size: int + digest: str + modified_at: str + details: OllamaModelDetails + +class OllamaTagResponse(BaseModel): + models: List[OllamaModel] + +def estimate_tokens(text: str) -> int: + """Estimate the number of tokens in text + Chinese characters: approximately 1.5 tokens per character + English characters: approximately 0.25 tokens per character + """ + # Use regex to match Chinese and non-Chinese characters separately + chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) + non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text)) + + # Calculate estimated token count + tokens = 
chinese_chars * 1.5 + non_chinese_chars * 0.25 + + return int(tokens) + +def parse_query_mode(query: str) -> tuple[str, SearchMode]: + """Parse query prefix to determine search mode + Returns tuple of (cleaned_query, search_mode) + """ + mode_map = { + "/local ": SearchMode.local, + "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword + "/naive ": SearchMode.naive, + "/hybrid ": SearchMode.hybrid, + "/mix ": SearchMode.mix, + "/bypass ": SearchMode.bypass, + } + + for prefix, mode in mode_map.items(): + if query.startswith(prefix): + # After removing prefix an leading spaces + cleaned_query = query[len(prefix) :].lstrip() + return cleaned_query, mode + + return query, SearchMode.hybrid + +class OllamaAPI: + def __init__(self, rag: LightRAG): + self.rag = rag + self.ollama_server_infos = ollama_server_infos + self.router = APIRouter() + self.setup_routes() + + def setup_routes(self): + @self.router.get("/version") + async def get_version(): + """Get Ollama version information""" + return OllamaVersionResponse(version="0.5.4") + + @self.router.get("/tags") + async def get_tags(): + """Return available models acting as an Ollama server""" + return OllamaTagResponse( + models=[ + { + "name": self.ollama_server_infos.LIGHTRAG_MODEL, + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "size": self.ollama_server_infos.LIGHTRAG_SIZE, + "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, + "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "details": { + "parent_model": "", + "format": "gguf", + "family": self.ollama_server_infos.LIGHTRAG_NAME, + "families": [self.ollama_server_infos.LIGHTRAG_NAME], + "parameter_size": "13B", + "quantization_level": "Q4_0", + }, + } + ] + ) + + @self.router.post("/generate") + async def generate(raw_request: Request, request: OllamaGenerateRequest): + """Handle generate completion requests acting as an Ollama model + For compatibility purpose, the request is not processed by LightRAG, + and will be handled by underlying LLM model. 
+ """ + try: + query = request.prompt + start_time = time.time_ns() + prompt_tokens = estimate_tokens(query) + + if request.system: + self.rag.llm_model_kwargs["system_prompt"] = request.system + + if request.stream: + response = await self.rag.llm_model_func( + query, stream=True, **self.rag.llm_model_kwargs + ) + + async def stream_generator(): + try: + first_chunk_time = None + last_chunk_time = None + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": response, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + return + + except Exception as e: + trace_exception(e) + raise + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + response_text = await self.rag.llm_model_func( + query, stream=False, **self.rag.llm_model_kwargs + ) + last_chunk_time = time.time_ns() + + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": str(response_text), + "done": True, + "total_duration": total_time, + 
"load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + @self.router.post("/chat") + async def chat(raw_request: Request, request: OllamaChatRequest): + """Process chat completion requests acting as an Ollama model + Routes user queries through LightRAG by selecting query mode based on prefix indicators. + Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. + """ + try: + # Get all messages + messages = request.messages + if not messages: + raise HTTPException(status_code=400, detail="No messages provided") + + # Get the last message as query and previous messages as history + query = messages[-1].content + # Convert OllamaMessage objects to dictionaries + conversation_history = [ + {"role": msg.role, "content": msg.content} for msg in messages[:-1] + ] + + # Check for query prefix + cleaned_query, mode = parse_query_mode(query) + + start_time = time.time_ns() + prompt_tokens = estimate_tokens(cleaned_query) + + param_dict = { + "mode": mode, + "stream": request.stream, + "only_need_context": False, + "conversation_history": conversation_history, + "top_k": self.rag.args.top_k if hasattr(self.rag, 'args') else 50, + } + + if hasattr(self.rag, 'args') and self.rag.args.history_turns is not None: + param_dict["history_turns"] = self.rag.args.history_turns + + query_param = QueryParam(**param_dict) + + if request.stream: + # Determine if the request is prefix with "/bypass" + if mode == SearchMode.bypass: + if request.system: + self.rag.llm_model_kwargs["system_prompt"] = request.system + response = await self.rag.llm_model_func( + cleaned_query, + stream=True, + history_messages=conversation_history, + **self.rag.llm_model_kwargs, + ) + else: + response = await self.rag.aquery( + cleaned_query, param=query_param + ) + + async def stream_generator(): + first_chunk_time = None + last_chunk_time = None + total_response = "" + + try: + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": response, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + 
"message": { + "role": "assistant", + "content": chunk, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logging.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": f"\n\nError: {error_msg}", + "images": None, + }, + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Send final message to close the stream + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return + + if last_chunk_time is not None: + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + except Exception as e: + error_msg = f"Error in stream_generator: {str(e)}" + logging.error(error_msg) + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "error": {"code": "STREAM_ERROR", "message": error_msg}, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Ensure sending end marker + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + + # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task + match_result = re.search( + r"\n\nUSER:", cleaned_query, re.MULTILINE + ) + if match_result or mode == SearchMode.bypass: + if request.system: + self.rag.llm_model_kwargs["system_prompt"] = request.system + + response_text = await self.rag.llm_model_func( + cleaned_query, + stream=False, + history_messages=conversation_history, + **self.rag.llm_model_kwargs, + ) + else: + response_text = await self.rag.aquery(cleaned_query, param=query_param) + + last_chunk_time = time.time_ns() + + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = 
last_chunk_time - first_chunk_time + + return { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": str(response_text), + "images": None, + }, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) From 1a61d9ee7f144053e3291c630611b349fa7dbefc Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Feb 2025 22:29:07 +0800 Subject: [PATCH 9/9] Fix linting --- lightrag/api/lightrag_server.py | 3 +++ lightrag/api/ollama_api.py | 26 +++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 2ee2838e..fc6f1580 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -41,12 +41,14 @@ from .ollama_api import ollama_server_infos # Load environment variables load_dotenv() + class RAGStorageConfig: KV_STORAGE = "JsonKVStorage" DOC_STATUS_STORAGE = "JsonDocStatusStorage" GRAPH_STORAGE = "NetworkXStorage" VECTOR_STORAGE = "NanoVectorDBStorage" + # Initialize rag storage config rag_storage_config = RAGStorageConfig() @@ -592,6 +594,7 @@ class SearchMode(str, Enum): hybrid = "hybrid" mix = "mix" + class QueryRequest(BaseModel): query: str mode: SearchMode = SearchMode.hybrid diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 49ec8414..e2637db0 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -12,6 +12,7 @@ import asyncio from ascii_colors import trace_exception from lightrag import LightRAG, QueryParam + class OllamaServerInfos: # Constants for emulated Ollama model information LIGHTRAG_NAME = "lightrag" @@ -21,8 +22,10 @@ class OllamaServerInfos: LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" LIGHTRAG_DIGEST = "sha256:lightrag" + ollama_server_infos = OllamaServerInfos() + # query mode according to query prefix (bypass is not LightRAG quer mode) class SearchMode(str, Enum): naive = "naive" @@ -32,11 +35,13 @@ class SearchMode(str, Enum): mix = "mix" bypass = "bypass" + class OllamaMessage(BaseModel): role: str content: str images: Optional[List[str]] = None + class OllamaChatRequest(BaseModel): model: str messages: List[OllamaMessage] @@ -44,12 +49,14 @@ class OllamaChatRequest(BaseModel): options: Optional[Dict[str, Any]] = None system: Optional[str] = None + class OllamaChatResponse(BaseModel): model: str created_at: str message: OllamaMessage done: bool + class OllamaGenerateRequest(BaseModel): model: str prompt: str @@ -57,6 +64,7 @@ class OllamaGenerateRequest(BaseModel): stream: bool = False options: Optional[Dict[str, Any]] = None + class OllamaGenerateResponse(BaseModel): model: str created_at: str @@ -70,9 +78,11 @@ class OllamaGenerateResponse(BaseModel): eval_count: Optional[int] eval_duration: Optional[int] + class OllamaVersionResponse(BaseModel): version: str + class OllamaModelDetails(BaseModel): parent_model: str format: str @@ -81,6 +91,7 @@ class OllamaModelDetails(BaseModel): parameter_size: str quantization_level: str + class OllamaModel(BaseModel): name: str model: str @@ -89,9 +100,11 @@ class OllamaModel(BaseModel): modified_at: str details: OllamaModelDetails + class OllamaTagResponse(BaseModel): models: List[OllamaModel] + def 
estimate_tokens(text: str) -> int: """Estimate the number of tokens in text Chinese characters: approximately 1.5 tokens per character @@ -106,6 +119,7 @@ def estimate_tokens(text: str) -> int: return int(tokens) + def parse_query_mode(query: str) -> tuple[str, SearchMode]: """Parse query prefix to determine search mode Returns tuple of (cleaned_query, search_mode) @@ -127,6 +141,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]: return query, SearchMode.hybrid + class OllamaAPI: def __init__(self, rag: LightRAG): self.rag = rag @@ -333,10 +348,13 @@ class OllamaAPI: "stream": request.stream, "only_need_context": False, "conversation_history": conversation_history, - "top_k": self.rag.args.top_k if hasattr(self.rag, 'args') else 50, + "top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50, } - if hasattr(self.rag, 'args') and self.rag.args.history_turns is not None: + if ( + hasattr(self.rag, "args") + and self.rag.args.history_turns is not None + ): param_dict["history_turns"] = self.rag.args.history_turns query_param = QueryParam(**param_dict) @@ -521,7 +539,9 @@ class OllamaAPI: **self.rag.llm_model_kwargs, ) else: - response_text = await self.rag.aquery(cleaned_query, param=query_param) + response_text = await self.rag.aquery( + cleaned_query, param=query_param + ) last_chunk_time = time.time_ns()
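
Taken together, the last two patches leave the emulated Ollama endpoints registered on a FastAPI APIRouter that the main server simply mounts under /api. A minimal sketch of that wiring, using simplified stand-in names rather than the real LightRAG classes:

    from fastapi import APIRouter, FastAPI

    class OllamaAPI:
        def __init__(self, rag):
            self.rag = rag                # shared LightRAG instance (None here for brevity)
            self.router = APIRouter()     # endpoints are registered on the router, not the app
            self.setup_routes()

        def setup_routes(self):
            @self.router.get("/version")
            async def get_version():
                return {"version": "0.5.4"}

    app = FastAPI()
    ollama_api = OllamaAPI(rag=None)
    app.include_router(ollama_api.router, prefix="/api")  # exposes GET /api/version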