added type validation

This commit is contained in:
Yannick Stephan
2025-02-16 19:24:12 +01:00
parent b09589cfd9
commit 7f5f44a646

View File

@@ -14,7 +14,7 @@ from fastapi.staticfiles import StaticFiles
import logging import logging
import argparse import argparse
from typing import List, Any, Literal, Optional, Dict from typing import List, Any, Literal, Optional, Dict
from pydantic import BaseModel from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.types import GPTKeywordExtractionFormat from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__ from lightrag.api import __api_version__
@@ -629,42 +629,42 @@ class QueryRequest(BaseModel):
query: str query: str
"""Specifies the retrieval mode""" """Specifies the retrieval mode"""
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "hybrid" mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(default="hybrid")
"""If True, only returns the retrieved context without generating a response.""" """If True, only returns the retrieved context without generating a response."""
only_need_context: Optional[bool] = None only_need_context: Optional[bool] = Field(default=None)
"""If True, only returns the generated prompt without producing a response.""" """If True, only returns the generated prompt without producing a response."""
only_need_prompt: Optional[bool] = None only_need_prompt: Optional[bool] = Field(default=None)
"""Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.""" """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
response_type: Optional[str] = None response_type: Optional[str] = Field(default=None)
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
top_k: Optional[int] = None top_k: Optional[int] = Field(default=None)
"""Maximum number of tokens allowed for each retrieved text chunk.""" """Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_text_unit: Optional[int] = None max_token_for_text_unit: Optional[int] = Field(default=None)
"""Maximum number of tokens allocated for relationship descriptions in global retrieval.""" """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_global_context: Optional[int] = None max_token_for_global_context: Optional[int] = Field(default=None)
"""Maximum number of tokens allocated for entity descriptions in local retrieval.""" """Maximum number of tokens allocated for entity descriptions in local retrieval."""
max_token_for_local_context: Optional[int] = None max_token_for_local_context: Optional[int] = Field(default=None)
"""List of high-level keywords to prioritize in retrieval.""" """List of high-level keywords to prioritize in retrieval."""
hl_keywords: Optional[List[str]] = None hl_keywords: Optional[List[str]] = Field(default=None)
"""List of low-level keywords to refine retrieval focus.""" """List of low-level keywords to refine retrieval focus."""
ll_keywords: Optional[List[str]] = None ll_keywords: Optional[List[str]] = Field(default=None)
"""Stores past conversation history to maintain context. """Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}]. Format: [{"role": "user/assistant", "content": "message"}].
""" """
conversation_history: Optional[List[dict[str, Any]]] = None conversation_history: Optional[List[dict[str, Any]]] = Field(default=None)
"""Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
history_turns: Optional[int] = None history_turns: Optional[int] = Field(default=None)
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
@@ -674,9 +674,22 @@ class QueryResponse(BaseModel):
class InsertTextRequest(BaseModel): class InsertTextRequest(BaseModel):
text: str text: str
@field_validator('text', mode='after')
@classmethod
def check_not_empty(cls, text: str) -> str:
if not text:
raise ValueError("Text cannot be empty")
return text
class InsertTextsRequest(BaseModel): class InsertTextsRequest(BaseModel):
texts: list[str] texts: list[str] = Field(default_factory=list)
@field_validator('texts', mode='after')
@classmethod
def check_not_empty(cls, texts: list[str]) -> list[str]:
if not texts:
raise ValueError("Texts cannot be empty")
return texts
class InsertResponse(BaseModel): class InsertResponse(BaseModel):