diff --git a/README.md b/README.md index ee5519f5..06e914ba 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News +- [x] [2025.02.05]🎯📢Our team has released [VideoRAG](https://github.com/HKUDS/VideoRAG) for processing and understanding extremely long-context videos. - [x] [2025.01.13]🎯📢Our team has released [MiniRAG](https://github.com/HKUDS/MiniRAG) making RAG simpler with small models. - [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage). - [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 27ce8a8d..fc6f1580 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -4,28 +4,20 @@ from fastapi import ( File, UploadFile, Form, - Request, BackgroundTasks, ) -# Backend (Python) -# Add this to store progress globally -from typing import Dict import threading - -import json import os - +import json +import re from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel import logging import argparse -import time -import re -from typing import List, Any, Optional, Union +from typing import List, Any, Optional, Union, Dict +from pydantic import BaseModel from lightrag import LightRAG, QueryParam from lightrag.api import __api_version__ - from lightrag.utils import EmbeddingFunc from enum import Enum from pathlib import Path @@ -34,20 +26,32 @@ import aiofiles from ascii_colors import trace_exception, ASCIIColors import sys import configparser - from fastapi import Depends, Security from fastapi.security import APIKeyHeader from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager - from starlette.status import HTTP_403_FORBIDDEN import pipmaster as pm - from dotenv import load_dotenv +from .ollama_api import ( + OllamaAPI, +) +from .ollama_api import ollama_server_infos # Load environment variables load_dotenv() + +class RAGStorageConfig: + KV_STORAGE = "JsonKVStorage" + DOC_STATUS_STORAGE = "JsonDocStatusStorage" + GRAPH_STORAGE = "NetworkXStorage" + VECTOR_STORAGE = "NanoVectorDBStorage" + + +# Initialize rag storage config +rag_storage_config = RAGStorageConfig() + # Global progress tracker scan_progress: Dict = { "is_scanning": False, @@ -76,24 +80,6 @@ def estimate_tokens(text: str) -> int: return int(tokens) -class OllamaServerInfos: - # Constants for emulated Ollama model information - LIGHTRAG_NAME = "lightrag" - LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") - LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" - LIGHTRAG_SIZE = 7365960935 # it's a dummy value - LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" - LIGHTRAG_DIGEST = "sha256:lightrag" - - KV_STORAGE = "JsonKVStorage" - DOC_STATUS_STORAGE = "JsonDocStatusStorage" - GRAPH_STORAGE = "NetworkXStorage" - VECTOR_STORAGE = "NanoVectorDBStorage" - - -# Add infos -ollama_server_infos = OllamaServerInfos() - # read config.ini config = configparser.ConfigParser() config.read("config.ini", "utf-8") @@ -101,8 +87,8 @@ config.read("config.ini", "utf-8") redis_uri = config.get("redis", "uri", fallback=None) if redis_uri: os.environ["REDIS_URI"] = redis_uri - ollama_server_infos.KV_STORAGE = "RedisKVStorage" - ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage" + rag_storage_config.KV_STORAGE = "RedisKVStorage" + rag_storage_config.DOC_STATUS_STORAGE = 
"RedisKVStorage" # Neo4j config neo4j_uri = config.get("neo4j", "uri", fallback=None) @@ -112,7 +98,7 @@ if neo4j_uri: os.environ["NEO4J_URI"] = neo4j_uri os.environ["NEO4J_USERNAME"] = neo4j_username os.environ["NEO4J_PASSWORD"] = neo4j_password - ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage" + rag_storage_config.GRAPH_STORAGE = "Neo4JStorage" # Milvus config milvus_uri = config.get("milvus", "uri", fallback=None) @@ -124,7 +110,7 @@ if milvus_uri: os.environ["MILVUS_USER"] = milvus_user os.environ["MILVUS_PASSWORD"] = milvus_password os.environ["MILVUS_DB_NAME"] = milvus_db_name - ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge" + rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge" # MongoDB config mongo_uri = config.get("mongodb", "uri", fallback=None) @@ -132,8 +118,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None) if mongo_uri: os.environ["MONGO_URI"] = mongo_uri os.environ["MONGO_DATABASE"] = mongo_database - ollama_server_infos.KV_STORAGE = "MongoKVStorage" - ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage" + rag_storage_config.KV_STORAGE = "MongoKVStorage" + rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage" def get_default_host(binding_type: str) -> str: @@ -535,6 +521,7 @@ def parse_args() -> argparse.Namespace: help="Cosine similarity threshold (default: from env or 0.4)", ) + # Ollama model name parser.add_argument( "--simulated-model-name", type=str, @@ -599,83 +586,13 @@ class DocumentManager: return any(filename.lower().endswith(ext) for ext in self.supported_extensions) -# Pydantic models +# LightRAG query mode class SearchMode(str, Enum): naive = "naive" local = "local" global_ = "global" hybrid = "hybrid" mix = "mix" - bypass = "bypass" - - -class OllamaMessage(BaseModel): - role: str - content: str - images: Optional[List[str]] = None - - -class OllamaChatRequest(BaseModel): - model: str = ollama_server_infos.LIGHTRAG_MODEL - messages: List[OllamaMessage] - stream: bool = True # Default to streaming mode - options: Optional[Dict[str, Any]] = None - system: Optional[str] = None - - -class OllamaChatResponse(BaseModel): - model: str - created_at: str - message: OllamaMessage - done: bool - - -class OllamaGenerateRequest(BaseModel): - model: str = ollama_server_infos.LIGHTRAG_MODEL - prompt: str - system: Optional[str] = None - stream: bool = False - options: Optional[Dict[str, Any]] = None - - -class OllamaGenerateResponse(BaseModel): - model: str - created_at: str - response: str - done: bool - context: Optional[List[int]] - total_duration: Optional[int] - load_duration: Optional[int] - prompt_eval_count: Optional[int] - prompt_eval_duration: Optional[int] - eval_count: Optional[int] - eval_duration: Optional[int] - - -class OllamaVersionResponse(BaseModel): - version: str - - -class OllamaModelDetails(BaseModel): - parent_model: str - format: str - family: str - families: List[str] - parameter_size: str - quantization_level: str - - -class OllamaModel(BaseModel): - name: str - model: str - size: int - digest: str - modified_at: str - details: OllamaModelDetails - - -class OllamaTagResponse(BaseModel): - models: List[OllamaModel] class QueryRequest(BaseModel): @@ -920,10 +837,10 @@ def create_app(args): if args.llm_binding == "lollms" or args.llm_binding == "ollama" else {}, embedding_func=embedding_func, - kv_storage=ollama_server_infos.KV_STORAGE, - graph_storage=ollama_server_infos.GRAPH_STORAGE, - vector_storage=ollama_server_infos.VECTOR_STORAGE, - doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE, + 
kv_storage=rag_storage_config.KV_STORAGE, + graph_storage=rag_storage_config.GRAPH_STORAGE, + vector_storage=rag_storage_config.VECTOR_STORAGE, + doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -949,10 +866,10 @@ def create_app(args): llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, embedding_func=embedding_func, - kv_storage=ollama_server_infos.KV_STORAGE, - graph_storage=ollama_server_infos.GRAPH_STORAGE, - vector_storage=ollama_server_infos.VECTOR_STORAGE, - doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE, + kv_storage=rag_storage_config.KV_STORAGE, + graph_storage=rag_storage_config.GRAPH_STORAGE, + vector_storage=rag_storage_config.VECTOR_STORAGE, + doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -1475,401 +1392,9 @@ def create_app(args): async def get_graphs(label: str): return await rag.get_graps(nodel_label=label, max_depth=100) - # Ollama compatible API endpoints - # ------------------------------------------------- - @app.get("/api/version") - async def get_version(): - """Get Ollama version information""" - return OllamaVersionResponse(version="0.5.4") - - @app.get("/api/tags") - async def get_tags(): - """Retrun available models acting as an Ollama server""" - return OllamaTagResponse( - models=[ - { - "name": ollama_server_infos.LIGHTRAG_MODEL, - "model": ollama_server_infos.LIGHTRAG_MODEL, - "size": ollama_server_infos.LIGHTRAG_SIZE, - "digest": ollama_server_infos.LIGHTRAG_DIGEST, - "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "details": { - "parent_model": "", - "format": "gguf", - "family": ollama_server_infos.LIGHTRAG_NAME, - "families": [ollama_server_infos.LIGHTRAG_NAME], - "parameter_size": "13B", - "quantization_level": "Q4_0", - }, - } - ] - ) - - def parse_query_mode(query: str) -> tuple[str, SearchMode]: - """Parse query prefix to determine search mode - Returns tuple of (cleaned_query, search_mode) - """ - mode_map = { - "/local ": SearchMode.local, - "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword - "/naive ": SearchMode.naive, - "/hybrid ": SearchMode.hybrid, - "/mix ": SearchMode.mix, - "/bypass ": SearchMode.bypass, - } - - for prefix, mode in mode_map.items(): - if query.startswith(prefix): - # After removing prefix an leading spaces - cleaned_query = query[len(prefix) :].lstrip() - return cleaned_query, mode - - return query, SearchMode.hybrid - - @app.post("/api/generate") - async def generate(raw_request: Request, request: OllamaGenerateRequest): - """Handle generate completion requests acting as an Ollama model - For compatiblity purpuse, the request is not processed by LightRAG, - and will be handled by underlying LLM model. 
- """ - try: - query = request.prompt - start_time = time.time_ns() - prompt_tokens = estimate_tokens(query) - - if request.system: - rag.llm_model_kwargs["system_prompt"] = request.system - - if request.stream: - from fastapi.responses import StreamingResponse - - response = await rag.llm_model_func( - query, stream=True, **rag.llm_model_kwargs - ) - - async def stream_generator(): - try: - first_chunk_time = None - last_chunk_time = None - total_response = "" - - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time - total_response = response - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": response, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - return - - except Exception as e: - logging.error(f"Error in stream_generator: {str(e)}") - raise - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - }, - ) - else: - first_chunk_time = time.time_ns() - response_text = await rag.llm_model_func( - query, stream=False, **rag.llm_model_kwargs - ) - last_chunk_time = time.time_ns() - - if not response_text: - response_text = "No response generated" - - completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - return { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": str(response_text), - "done": True, - "total_duration": 
total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/api/chat") - async def chat(raw_request: Request, request: OllamaChatRequest): - """Process chat completion requests acting as an Ollama model - Routes user queries through LightRAG by selecting query mode based on prefix indicators. - Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. - """ - try: - # Get all messages - messages = request.messages - if not messages: - raise HTTPException(status_code=400, detail="No messages provided") - - # Get the last message as query and previous messages as history - query = messages[-1].content - # Convert OllamaMessage objects to dictionaries - conversation_history = [ - {"role": msg.role, "content": msg.content} for msg in messages[:-1] - ] - - # Check for query prefix - cleaned_query, mode = parse_query_mode(query) - - start_time = time.time_ns() - prompt_tokens = estimate_tokens(cleaned_query) - - param_dict = { - "mode": mode, - "stream": request.stream, - "only_need_context": False, - "conversation_history": conversation_history, - "top_k": args.top_k, - } - - if args.history_turns is not None: - param_dict["history_turns"] = args.history_turns - - query_param = QueryParam(**param_dict) - - if request.stream: - from fastapi.responses import StreamingResponse - - # Determine if the request is prefix with "/bypass" - if mode == SearchMode.bypass: - if request.system: - rag.llm_model_kwargs["system_prompt"] = request.system - response = await rag.llm_model_func( - cleaned_query, - stream=True, - history_messages=conversation_history, - **rag.llm_model_kwargs, - ) - else: - response = await rag.aquery( # Need await to get async generator - cleaned_query, param=query_param - ) - - async def stream_generator(): - try: - first_chunk_time = None - last_chunk_time = None - total_response = "" - - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time - total_response = response - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": response, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk, - 
"images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - return # Ensure the generator ends immediately after sending the completion marker - except Exception as e: - logging.error(f"Error in stream_generator: {str(e)}") - raise - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - }, - ) - else: - first_chunk_time = time.time_ns() - - # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task - match_result = re.search( - r"\n\nUSER:", cleaned_query, re.MULTILINE - ) - if match_result or mode == SearchMode.bypass: - if request.system: - rag.llm_model_kwargs["system_prompt"] = request.system - - response_text = await rag.llm_model_func( - cleaned_query, - stream=False, - history_messages=conversation_history, - **rag.llm_model_kwargs, - ) - else: - response_text = await rag.aquery(cleaned_query, param=query_param) - - last_chunk_time = time.time_ns() - - if not response_text: - response_text = "No response generated" - - completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - return { - "model": ollama_server_infos.LIGHTRAG_MODEL, - "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": str(response_text), - "images": None, - }, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - except Exception as e: - trace_exception(e) - raise HTTPException(status_code=500, detail=str(e)) + # Add Ollama API routes + ollama_api = OllamaAPI(rag) + app.include_router(ollama_api.router, prefix="/api") @app.get("/documents", dependencies=[Depends(optional_api_key)]) async def documents(): @@ -1896,10 +1421,10 @@ def create_app(args): "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, "max_tokens": args.max_tokens, - "kv_storage": ollama_server_infos.KV_STORAGE, - "doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE, - "graph_storage": ollama_server_infos.GRAPH_STORAGE, - "vector_storage": ollama_server_infos.VECTOR_STORAGE, + "kv_storage": rag_storage_config.KV_STORAGE, + "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE, + "graph_storage": rag_storage_config.GRAPH_STORAGE, + "vector_storage": rag_storage_config.VECTOR_STORAGE, }, } diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py new file mode 
100644 index 00000000..e2637db0 --- /dev/null +++ b/lightrag/api/ollama_api.py @@ -0,0 +1,574 @@ +from fastapi import APIRouter, HTTPException, Request +from pydantic import BaseModel +from typing import List, Dict, Any, Optional +import logging +import time +import json +import re +import os +from enum import Enum +from fastapi.responses import StreamingResponse +import asyncio +from ascii_colors import trace_exception +from lightrag import LightRAG, QueryParam + + +class OllamaServerInfos: + # Constants for emulated Ollama model information + LIGHTRAG_NAME = "lightrag" + LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") + LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" + LIGHTRAG_SIZE = 7365960935 # it's a dummy value + LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" + LIGHTRAG_DIGEST = "sha256:lightrag" + + +ollama_server_infos = OllamaServerInfos() + + +# query mode according to query prefix (bypass is not a LightRAG query mode) +class SearchMode(str, Enum): + naive = "naive" + local = "local" + global_ = "global" + hybrid = "hybrid" + mix = "mix" + bypass = "bypass" + + +class OllamaMessage(BaseModel): + role: str + content: str + images: Optional[List[str]] = None + + +class OllamaChatRequest(BaseModel): + model: str + messages: List[OllamaMessage] + stream: bool = True + options: Optional[Dict[str, Any]] = None + system: Optional[str] = None + + +class OllamaChatResponse(BaseModel): + model: str + created_at: str + message: OllamaMessage + done: bool + + +class OllamaGenerateRequest(BaseModel): + model: str + prompt: str + system: Optional[str] = None + stream: bool = False + options: Optional[Dict[str, Any]] = None + + +class OllamaGenerateResponse(BaseModel): + model: str + created_at: str + response: str + done: bool + context: Optional[List[int]] + total_duration: Optional[int] + load_duration: Optional[int] + prompt_eval_count: Optional[int] + prompt_eval_duration: Optional[int] + eval_count: Optional[int] + eval_duration: Optional[int] + + +class OllamaVersionResponse(BaseModel): + version: str + + +class OllamaModelDetails(BaseModel): + parent_model: str + format: str + family: str + families: List[str] + parameter_size: str + quantization_level: str + + +class OllamaModel(BaseModel): + name: str + model: str + size: int + digest: str + modified_at: str + details: OllamaModelDetails + + +class OllamaTagResponse(BaseModel): + models: List[OllamaModel] + + +def estimate_tokens(text: str) -> int: + """Estimate the number of tokens in text + Chinese characters: approximately 1.5 tokens per character + English characters: approximately 0.25 tokens per character + """ + # Use regex to match Chinese and non-Chinese characters separately + chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) + non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text)) + + # Calculate estimated token count + tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 + + return int(tokens) + + +def parse_query_mode(query: str) -> tuple[str, SearchMode]: + """Parse query prefix to determine search mode + Returns tuple of (cleaned_query, search_mode) + """ + mode_map = { + "/local ": SearchMode.local, + "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword + "/naive ": SearchMode.naive, + "/hybrid ": SearchMode.hybrid, + "/mix ": SearchMode.mix, + "/bypass ": SearchMode.bypass, + } + + for prefix, mode in mode_map.items(): + if query.startswith(prefix): + # After removing prefix and leading spaces + cleaned_query = query[len(prefix) :].lstrip() + return 
cleaned_query, mode + + return query, SearchMode.hybrid + + +class OllamaAPI: + def __init__(self, rag: LightRAG): + self.rag = rag + self.ollama_server_infos = ollama_server_infos + self.router = APIRouter() + self.setup_routes() + + def setup_routes(self): + @self.router.get("/version") + async def get_version(): + """Get Ollama version information""" + return OllamaVersionResponse(version="0.5.4") + + @self.router.get("/tags") + async def get_tags(): + """Return available models acting as an Ollama server""" + return OllamaTagResponse( + models=[ + { + "name": self.ollama_server_infos.LIGHTRAG_MODEL, + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "size": self.ollama_server_infos.LIGHTRAG_SIZE, + "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, + "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "details": { + "parent_model": "", + "format": "gguf", + "family": self.ollama_server_infos.LIGHTRAG_NAME, + "families": [self.ollama_server_infos.LIGHTRAG_NAME], + "parameter_size": "13B", + "quantization_level": "Q4_0", + }, + } + ] + ) + + @self.router.post("/generate") + async def generate(raw_request: Request, request: OllamaGenerateRequest): + """Handle generate completion requests acting as an Ollama model + For compatibility purpose, the request is not processed by LightRAG, + and will be handled by underlying LLM model. + """ + try: + query = request.prompt + start_time = time.time_ns() + prompt_tokens = estimate_tokens(query) + + if request.system: + self.rag.llm_model_kwargs["system_prompt"] = request.system + + if request.stream: + response = await self.rag.llm_model_func( + query, stream=True, **self.rag.llm_model_kwargs + ) + + async def stream_generator(): + try: + first_chunk_time = None + last_chunk_time = None + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": response, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": 
self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + return + + except Exception as e: + trace_exception(e) + raise + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + response_text = await self.rag.llm_model_func( + query, stream=False, **self.rag.llm_model_kwargs + ) + last_chunk_time = time.time_ns() + + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": str(response_text), + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + @self.router.post("/chat") + async def chat(raw_request: Request, request: OllamaChatRequest): + """Process chat completion requests acting as an Ollama model + Routes user queries through LightRAG by selecting query mode based on prefix indicators. + Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. 
+ """ + try: + # Get all messages + messages = request.messages + if not messages: + raise HTTPException(status_code=400, detail="No messages provided") + + # Get the last message as query and previous messages as history + query = messages[-1].content + # Convert OllamaMessage objects to dictionaries + conversation_history = [ + {"role": msg.role, "content": msg.content} for msg in messages[:-1] + ] + + # Check for query prefix + cleaned_query, mode = parse_query_mode(query) + + start_time = time.time_ns() + prompt_tokens = estimate_tokens(cleaned_query) + + param_dict = { + "mode": mode, + "stream": request.stream, + "only_need_context": False, + "conversation_history": conversation_history, + "top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50, + } + + if ( + hasattr(self.rag, "args") + and self.rag.args.history_turns is not None + ): + param_dict["history_turns"] = self.rag.args.history_turns + + query_param = QueryParam(**param_dict) + + if request.stream: + # Determine if the request is prefix with "/bypass" + if mode == SearchMode.bypass: + if request.system: + self.rag.llm_model_kwargs["system_prompt"] = request.system + response = await self.rag.llm_model_func( + cleaned_query, + stream=True, + history_messages=conversation_history, + **self.rag.llm_model_kwargs, + ) + else: + response = await self.rag.aquery( + cleaned_query, param=query_param + ) + + async def stream_generator(): + first_chunk_time = None + last_chunk_time = None + total_response = "" + + try: + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": response, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logging.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + 
"content": f"\n\nError: {error_msg}", + "images": None, + }, + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Send final message to close the stream + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return + + if last_chunk_time is not None: + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + except Exception as e: + error_msg = f"Error in stream_generator: {str(e)}" + logging.error(error_msg) + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "error": {"code": "STREAM_ERROR", "message": error_msg}, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Ensure sending end marker + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + + # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task + match_result = re.search( + r"\n\nUSER:", cleaned_query, re.MULTILINE + ) + if match_result or mode == SearchMode.bypass: + if request.system: + self.rag.llm_model_kwargs["system_prompt"] = request.system + + response_text = await self.rag.llm_model_func( + cleaned_query, + stream=False, + history_messages=conversation_history, + **self.rag.llm_model_kwargs, + ) + else: + response_text = await self.rag.aquery( + cleaned_query, param=query_param + ) + + last_chunk_time = time.time_ns() + + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": str(response_text), + "images": None, + }, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + except Exception as e: + trace_exception(e) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/lightrag/lightrag.py 
b/lightrag/lightrag.py index 3014f737..420b82eb 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -372,12 +372,23 @@ class LightRAG: # 3. Filter out already processed documents # _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) - _add_doc_keys = { - doc_id - for doc_id in new_docs.keys() - if (current_doc := await self.doc_status.get_by_id(doc_id)) is None - or current_doc.status == DocStatus.FAILED - } + _add_doc_keys = set() + for doc_id in new_docs.keys(): + current_doc = await self.doc_status.get_by_id(doc_id) + + if current_doc is None: + _add_doc_keys.add(doc_id) + continue # skip to the next doc_id + + status = None + if isinstance(current_doc, dict): + status = current_doc["status"] + else: + status = current_doc.status + + if status == DocStatus.FAILED: + _add_doc_keys.add(doc_id) + new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} if not new_docs: diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 11ba69c0..4eaca093 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -125,13 +125,17 @@ async def openai_complete_if_cache( if hasattr(response, "__aiter__"): async def inner(): - async for chunk in response: - content = chunk.choices[0].delta.content - if content is None: - continue - if r"\u" in content: - content = safe_unicode_decode(content.encode("utf-8")) - yield content + try: + async for chunk in response: + content = chunk.choices[0].delta.content + if content is None: + continue + if r"\u" in content: + content = safe_unicode_decode(content.encode("utf-8")) + yield content + except Exception as e: + logger.error(f"Error in stream response: {str(e)}") + raise return inner() else:
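
This patch moves the Ollama-compatible endpoints out of `lightrag_server.py` into a self-contained `OllamaAPI` class that owns an `APIRouter`, so the server only constructs it and mounts it. Below is a minimal sketch (not part of the patch) of that wiring; it assumes a bare `LightRAG(working_dir=...)` is enough for a demo, whereas a real deployment also passes the LLM and embedding bindings the way `lightrag_server.py` does, and the queries fed to `parse_query_mode` are illustrative only.

```python
from fastapi import FastAPI
from lightrag import LightRAG
from lightrag.api.ollama_api import OllamaAPI, SearchMode, parse_query_mode

# Assumption: default storage backends and a working directory suffice for this sketch;
# lightrag_server.py additionally wires llm_model_func, embedding_func, storage classes, etc.
rag = LightRAG(working_dir="./rag_storage")

app = FastAPI()
ollama_api = OllamaAPI(rag)
# Mounting under /api exposes /api/version, /api/tags, /api/generate and /api/chat,
# matching the endpoints previously defined inline in lightrag_server.py.
app.include_router(ollama_api.router, prefix="/api")

# /api/chat routes by query prefix: "/local ", "/global ", "/naive ", "/hybrid " and
# "/mix " select a LightRAG query mode, "/bypass " goes straight to the LLM, and a
# query without a prefix falls back to hybrid.
assert parse_query_mode("/local who is Scrooge?") == ("who is Scrooge?", SearchMode.local)
assert parse_query_mode("plain question") == ("plain question", SearchMode.hybrid)
```

One behavioral difference worth noting: the extracted `chat` handler reads `top_k` and `history_turns` from `self.rag.args` when that attribute exists (falling back to `top_k=50`), whereas the old inline endpoint read them from the server's parsed CLI `args`, so the router no longer depends on the argparse namespace being in scope.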