Fix linting

yangdx
2025-02-05 22:29:07 +08:00
parent f703334ce4
commit 1a61d9ee7f
2 changed files with 26 additions and 3 deletions

View File

@@ -41,12 +41,14 @@ from .ollama_api import ollama_server_infos
# Load environment variables
load_dotenv()
class RAGStorageConfig:
    KV_STORAGE = "JsonKVStorage"
    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
    GRAPH_STORAGE = "NetworkXStorage"
    VECTOR_STORAGE = "NanoVectorDBStorage"

# Initialize rag storage config
rag_storage_config = RAGStorageConfig()
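# Sketch (not part of this commit): how these defaults would be consumed,
# assuming the LightRAG constructor accepts storage backends by class name,
# which is how the string values above are worded. Path is hypothetical.
rag = LightRAG(
    working_dir="./rag_storage",
    kv_storage=rag_storage_config.KV_STORAGE,
    doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
    graph_storage=rag_storage_config.GRAPH_STORAGE,
    vector_storage=rag_storage_config.VECTOR_STORAGE,
)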
@@ -592,6 +594,7 @@ class SearchMode(str, Enum):
hybrid = "hybrid"
mix = "mix"
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
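# Sketch: because SearchMode subclasses str, pydantic validates plain strings
# against the enum values, so both of these construct (assumed usage):
QueryRequest(query="What is LightRAG?")  # mode defaults to SearchMode.hybrid
QueryRequest(query="Compare the two backends", mode="mix")  # coerced to SearchMode.mix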

View File

@@ -12,6 +12,7 @@ import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
class OllamaServerInfos:
    # Constants for emulated Ollama model information
    LIGHTRAG_NAME = "lightrag"
@@ -21,8 +22,10 @@ class OllamaServerInfos:
    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
    LIGHTRAG_DIGEST = "sha256:lightrag"

ollama_server_infos = OllamaServerInfos()

# Select query mode according to query prefix (bypass is not a LightRAG query mode)
class SearchMode(str, Enum):
    naive = "naive"
@@ -32,11 +35,13 @@ class SearchMode(str, Enum):
mix = "mix"
bypass = "bypass"
class OllamaMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel):
model: str
messages: List[OllamaMessage]
@@ -44,12 +49,14 @@ class OllamaChatRequest(BaseModel):
    options: Optional[Dict[str, Any]] = None
    system: Optional[str] = None

class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool
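# Sketch of a request/response pair these models validate; the model name string
# is an assumption (the constant holding it sits outside the hunks above):
request = OllamaChatRequest(
    model="lightrag:latest",
    messages=[OllamaMessage(role="user", content="/mix How do the parts relate?")],
)
reply = OllamaChatResponse(
    model=request.model,
    created_at=OllamaServerInfos.LIGHTRAG_CREATED_AT,
    message=OllamaMessage(role="assistant", content="..."),
    done=True,
)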
class OllamaGenerateRequest(BaseModel):
    model: str
    prompt: str
@@ -57,6 +64,7 @@ class OllamaGenerateRequest(BaseModel):
    stream: bool = False
    options: Optional[Dict[str, Any]] = None

class OllamaGenerateResponse(BaseModel):
    model: str
    created_at: str
@@ -70,9 +78,11 @@ class OllamaGenerateResponse(BaseModel):
    eval_count: Optional[int]
    eval_duration: Optional[int]

class OllamaVersionResponse(BaseModel):
    version: str

class OllamaModelDetails(BaseModel):
    parent_model: str
    format: str
@@ -81,6 +91,7 @@ class OllamaModelDetails(BaseModel):
    parameter_size: str
    quantization_level: str

class OllamaModel(BaseModel):
    name: str
    model: str
@@ -89,9 +100,11 @@ class OllamaModel(BaseModel):
    modified_at: str
    details: OllamaModelDetails

class OllamaTagResponse(BaseModel):
    models: List[OllamaModel]

def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in text
    Chinese characters: approximately 1.5 tokens per character
@@ -106,6 +119,7 @@ def estimate_tokens(text: str) -> int:
    return int(tokens)
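# Minimal sketch of the heuristic the docstring describes. The 0.25 rate for
# non-Chinese characters is an assumption; only the Chinese rate is visible here.
import re

def estimate_tokens_sketch(text: str) -> int:
    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    other_chars = len(text) - chinese_chars
    return int(chinese_chars * 1.5 + other_chars * 0.25)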
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
    """Parse query prefix to determine search mode
    Returns tuple of (cleaned_query, search_mode)
@@ -127,6 +141,7 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
    return query, SearchMode.hybrid
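# Sketch of the prefix convention implied above, assuming queries may start with
# "/<mode> " (e.g. "/local ", "/bypass ") and anything else falls through to the
# hybrid default returned on the line just before this:
def parse_query_mode_sketch(query: str) -> tuple[str, SearchMode]:
    for mode in SearchMode:
        prefix = f"/{mode.value} "
        if query.startswith(prefix):
            return query[len(prefix):].lstrip(), mode
    return query, SearchMode.hybrid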
class OllamaAPI:
    def __init__(self, rag: LightRAG):
        self.rag = rag
@@ -333,10 +348,13 @@ class OllamaAPI:
"stream": request.stream,
"only_need_context": False,
"conversation_history": conversation_history,
"top_k": self.rag.args.top_k if hasattr(self.rag, 'args') else 50,
"top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
}
if hasattr(self.rag, 'args') and self.rag.args.history_turns is not None:
if (
hasattr(self.rag, "args")
and self.rag.args.history_turns is not None
):
param_dict["history_turns"] = self.rag.args.history_turns
query_param = QueryParam(**param_dict)
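# Sketch: with a bare LightRAG instance (no parsed CLI args), the guards above
# reduce to these literals; the keys are assumed to all be QueryParam fields:
param_dict = {
    "stream": False,
    "only_need_context": False,
    "conversation_history": [],
    "top_k": 50,
}
query_param = QueryParam(**param_dict)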
@@ -521,7 +539,9 @@ class OllamaAPI:
                        **self.rag.llm_model_kwargs,
                    )
                else:
-                    response_text = await self.rag.aquery(cleaned_query, param=query_param)
+                    response_text = await self.rag.aquery(
+                        cleaned_query, param=query_param
+                    )

                last_chunk_time = time.time_ns()