From 7f5f44a646b524a894729004e75d16ffa1a68b7f Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 19:24:12 +0100 Subject: [PATCH] added type validation --- lightrag/api/lightrag_server.py | 41 ++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index e2160b12..07ba6505 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -14,7 +14,7 @@ from fastapi.staticfiles import StaticFiles import logging import argparse from typing import List, Any, Literal, Optional, Dict -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG, QueryParam from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ @@ -629,42 +629,42 @@ class QueryRequest(BaseModel): query: str """Specifies the retrieval mode""" - mode: Literal["local", "global", "hybrid", "naive", "mix"] = "hybrid" + mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(default="hybrid") """If True, only returns the retrieved context without generating a response.""" - only_need_context: Optional[bool] = None + only_need_context: Optional[bool] = Field(default=None) """If True, only returns the generated prompt without producing a response.""" - only_need_prompt: Optional[bool] = None + only_need_prompt: Optional[bool] = Field(default=None) """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.""" - response_type: Optional[str] = None + response_type: Optional[str] = Field(default=None) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" - top_k: Optional[int] = None + top_k: Optional[int] = Field(default=None) """Maximum number of tokens allowed for each retrieved text chunk.""" - max_token_for_text_unit: Optional[int] = None + max_token_for_text_unit: Optional[int] = Field(default=None) """Maximum number of tokens allocated for relationship descriptions in global retrieval.""" - max_token_for_global_context: Optional[int] = None + max_token_for_global_context: Optional[int] = Field(default=None) """Maximum number of tokens allocated for entity descriptions in local retrieval.""" - max_token_for_local_context: Optional[int] = None + max_token_for_local_context: Optional[int] = Field(default=None) """List of high-level keywords to prioritize in retrieval.""" - hl_keywords: Optional[List[str]] = None + hl_keywords: Optional[List[str]] = Field(default=None) """List of low-level keywords to refine retrieval focus.""" - ll_keywords: Optional[List[str]] = None + ll_keywords: Optional[List[str]] = Field(default=None) """Stores past conversation history to maintain context. Format: [{"role": "user/assistant", "content": "message"}]. """ - conversation_history: Optional[List[dict[str, Any]]] = None + conversation_history: Optional[List[dict[str, Any]]] = Field(default=None) """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" - history_turns: Optional[int] = None + history_turns: Optional[int] = Field(default=None) class QueryResponse(BaseModel): @@ -674,9 +674,22 @@ class QueryResponse(BaseModel): class InsertTextRequest(BaseModel): text: str + @field_validator('text', mode='after') + @classmethod + def check_not_empty(cls, text: str) -> str: + if not text: + raise ValueError("Text cannot be empty") + return text class InsertTextsRequest(BaseModel): - texts: list[str] + texts: list[str] = Field(default_factory=list) + + @field_validator('texts', mode='after') + @classmethod + def check_not_empty(cls, texts: list[str]) -> list[str]: + if not texts: + raise ValueError("Texts cannot be empty") + return texts class InsertResponse(BaseModel):