From 3119f76e26197c169a0fc42fb1ed33438c24d604 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 19:42:09 +0100 Subject: [PATCH] added field check --- lightrag/api/lightrag_server.py | 157 +++++++++++++++++++++++--------- 1 file changed, 115 insertions(+), 42 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 29146f59..4ce8ddb3 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -626,55 +626,125 @@ class DocumentManager: class QueryRequest(BaseModel): - query: str - - """Specifies the retrieval mode""" - mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(default="hybrid") - - """If True, only returns the retrieved context without generating a response.""" - only_need_context: Optional[bool] = Field(default=None) - - """If True, only returns the generated prompt without producing a response.""" - only_need_prompt: Optional[bool] = Field(default=None) - - """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.""" - response_type: Optional[str] = Field(min_length=1, default=None) - - """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" - top_k: Optional[int] = Field(gt=1, default=None) - - """Maximum number of tokens allowed for each retrieved text chunk.""" - max_token_for_text_unit: Optional[int] = Field(gt=1, default=None) - - """Maximum number of tokens allocated for relationship descriptions in global retrieval.""" - max_token_for_global_context: Optional[int] = Field(gt=1, default=None) - - """Maximum number of tokens allocated for entity descriptions in local retrieval.""" - max_token_for_local_context: Optional[int] = Field(gt=1, default=None) - - """List of high-level keywords to prioritize in retrieval.""" - hl_keywords: Optional[List[str]] = Field(min_length=1, default=None) - - """List of low-level keywords to refine retrieval focus.""" - ll_keywords: Optional[List[str]] = Field(min_length=1, default=None) - - """Stores past conversation history to maintain context. - Format: [{"role": "user/assistant", "content": "message"}]. - """ - conversation_history: Optional[List[dict[str, Any]]] = Field( - min_length=1, default=None + query: str = Field( + min_length=1, + description="The query text", ) - """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" - history_turns: Optional[int] = Field(gt=1, default=None) + mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field( + default="hybrid", + description="Query mode", + ) + + only_need_context: Optional[bool] = Field( + default=None, + description="If True, only returns the retrieved context without generating a response.", + ) + + only_need_prompt: Optional[bool] = Field( + default=None, + description="If True, only returns the generated prompt without producing a response.", + ) + + response_type: Optional[str] = Field( + min_length=1, + default=None, + description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", + ) + + top_k: Optional[int] = Field( + gt=1, + default=None, + description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", + ) + + max_token_for_text_unit: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allowed for each retrieved text chunk.", + ) + + max_token_for_global_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", + ) + + max_token_for_local_context: Optional[int] = Field( + gt=1, + default=None, + description="Maximum number of tokens allocated for entity descriptions in local retrieval.", + ) + + hl_keywords: Optional[List[str]] = Field( + min_length=1, + default=None, + description="List of high-level keywords to prioritize in retrieval.", + ) + + ll_keywords: Optional[List[str]] = Field( + min_length=1, + default=None, + description="List of low-level keywords to refine retrieval focus.", + ) + + conversation_history: Optional[List[dict[str, Any]]] = Field( + min_length=1, + default=None, + description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].", + ) + + history_turns: Optional[int] = Field( + gt=1, + default=None, + description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", + ) + + @field_validator("query", mode="after") + @classmethod + def query_strip_after(cls, query: str) -> str: + return query.strip() + + @field_validator("hl_keywords", mode="after") + @classmethod + def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None: + if hl_keywords is None: + return None + return [keyword.strip() for keyword in hl_keywords] + + @field_validator("ll_keywords", mode="after") + @classmethod + def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None: + if ll_keywords is None: + return None + return [keyword.strip() for keyword in ll_keywords] + + @field_validator("conversation_history", mode="after") + @classmethod + def conversation_history_role_check( + cls, conversation_history: List[dict[str, Any]] | None + ) -> List[dict[str, Any]] | None: + if conversation_history is None: + return None + for msg in conversation_history: + if "role" not in msg or msg["role"] not in {"user", "assistant"}: + raise ValueError( + "Each message must have a 'role' key with value 'user' or 'assistant'." + ) + return conversation_history class QueryResponse(BaseModel): - response: str + response: str = Field( + description="The generated response", + ) class InsertTextRequest(BaseModel): - text: str = Field(min_length=1) + text: str = Field( + min_length=1, + description="The text to insert", + ) @field_validator("text", mode="after") @classmethod @@ -683,7 +753,10 @@ class InsertTextRequest(BaseModel): class InsertTextsRequest(BaseModel): - texts: list[str] = Field(min_length=1) + texts: list[str] = Field( + min_length=1, + description="The texts to insert", + ) @field_validator("texts", mode="after") @classmethod