diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 2ee2838e..fc6f1580 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -41,12 +41,14 @@ from .ollama_api import ollama_server_infos
 
 # Load environment variables
 load_dotenv()
 
+
 class RAGStorageConfig:
     KV_STORAGE = "JsonKVStorage"
     DOC_STATUS_STORAGE = "JsonDocStatusStorage"
     GRAPH_STORAGE = "NetworkXStorage"
     VECTOR_STORAGE = "NanoVectorDBStorage"
 
+
 # Initialize rag storage config
 rag_storage_config = RAGStorageConfig()
@@ -592,6 +594,7 @@ class SearchMode(str, Enum):
     hybrid = "hybrid"
     mix = "mix"
 
+
 class QueryRequest(BaseModel):
     query: str
     mode: SearchMode = SearchMode.hybrid
diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py
index 49ec8414..e2637db0 100644
--- a/lightrag/api/ollama_api.py
+++ b/lightrag/api/ollama_api.py
@@ -12,6 +12,7 @@ import asyncio
 from ascii_colors import trace_exception
 from lightrag import LightRAG, QueryParam
 
+
 class OllamaServerInfos:
     # Constants for emulated Ollama model information
     LIGHTRAG_NAME = "lightrag"
@@ -21,8 +22,10 @@ class OllamaServerInfos:
     LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
     LIGHTRAG_DIGEST = "sha256:lightrag"
 
+
 ollama_server_infos = OllamaServerInfos()
 
+
 # query mode according to query prefix (bypass is not LightRAG quer mode)
 class SearchMode(str, Enum):
     naive = "naive"
@@ -32,11 +35,13 @@ class SearchMode(str, Enum):
     mix = "mix"
     bypass = "bypass"
 
+
 class OllamaMessage(BaseModel):
     role: str
     content: str
     images: Optional[List[str]] = None
 
+
 class OllamaChatRequest(BaseModel):
     model: str
     messages: List[OllamaMessage]
@@ -44,12 +49,14 @@ class OllamaChatRequest(BaseModel):
     options: Optional[Dict[str, Any]] = None
     system: Optional[str] = None
 
+
 class OllamaChatResponse(BaseModel):
     model: str
     created_at: str
     message: OllamaMessage
     done: bool
 
+
 class OllamaGenerateRequest(BaseModel):
     model: str
     prompt: str
@@ -57,6 +64,7 @@ class OllamaGenerateRequest(BaseModel):
     stream: bool = False
     options: Optional[Dict[str, Any]] = None
 
+
 class OllamaGenerateResponse(BaseModel):
     model: str
     created_at: str
@@ -70,9 +78,11 @@ class OllamaGenerateResponse(BaseModel):
     eval_count: Optional[int]
     eval_duration: Optional[int]
 
+
 class OllamaVersionResponse(BaseModel):
     version: str
 
+
 class OllamaModelDetails(BaseModel):
     parent_model: str
     format: str
@@ -81,6 +91,7 @@ class OllamaModelDetails(BaseModel):
     parameter_size: str
     quantization_level: str
 
+
 class OllamaModel(BaseModel):
     name: str
     model: str
@@ -89,9 +100,11 @@ class OllamaModel(BaseModel):
     modified_at: str
     details: OllamaModelDetails
 
+
 class OllamaTagResponse(BaseModel):
     models: List[OllamaModel]
 
+
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text
     Chinese characters: approximately 1.5 tokens per character
     English characters: approximately 0.25 tokens per character
@@ -106,6 +119,7 @@ def estimate_tokens(text: str) -> int:
 
     return int(tokens)
 
+
 def parse_query_mode(query: str) -> tuple[str, SearchMode]:
     """Parse query prefix to determine search mode
     Returns tuple of (cleaned_query, search_mode)
@@ -127,6 +141,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
 
     return query, SearchMode.hybrid
 
+
 class OllamaAPI:
     def __init__(self, rag: LightRAG):
         self.rag = rag
@@ -333,10 +348,13 @@ class OllamaAPI:
                     "stream": request.stream,
                     "only_need_context": False,
                     "conversation_history": conversation_history,
-                    "top_k": self.rag.args.top_k if hasattr(self.rag, 'args') else 50,
+                    "top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
                 }
 
-                if hasattr(self.rag, 'args') and self.rag.args.history_turns is not None:
+                if (
+                    hasattr(self.rag, "args")
+                    and self.rag.args.history_turns is not None
+                ):
                     param_dict["history_turns"] = self.rag.args.history_turns
 
                 query_param = QueryParam(**param_dict)
@@ -521,7 +539,9 @@ class OllamaAPI:
                         **self.rag.llm_model_kwargs,
                     )
                 else:
-                    response_text = await self.rag.aquery(cleaned_query, param=query_param)
+                    response_text = await self.rag.aquery(
+                        cleaned_query, param=query_param
+                    )
 
                 last_chunk_time = time.time_ns()