improved typing

Author: Yannick Stephan
Date: 2025-02-15 22:37:12 +01:00
parent 8d0d8b8279
commit eaf1d553d2
9 changed files with 52 additions and 35 deletions


@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 import bs4
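
The `from __future__ import annotations` line added at the top is what lets the rest of this commit drop `Union`, `Optional` and `List`: under PEP 563 annotations are stored as strings and never evaluated at definition time, so the PEP 604 (`str | None`) and PEP 585 (`list[str]`) spellings are safe even on interpreters older than 3.10/3.9. A minimal, self-contained sketch (not code from this repository) showing the effect:

```python
from __future__ import annotations  # PEP 563: annotations are stored as strings, not evaluated


# Without the future import, this return annotation would be evaluated at definition
# time and raise TypeError on Python 3.8 (builtin generics) or 3.9 (the | union syntax).
def split_rows(csv_text: str) -> list[list[str]] | None:
    rows = [line.split(",") for line in csv_text.splitlines()]
    return rows or None


print(split_rows.__annotations__)
# {'csv_text': 'str', 'return': 'list[list[str]] | None'} -- plain strings, nothing executed
```

The one caveat is that anything which evaluates annotations at runtime (for example `typing.get_type_hints`) still needs Python 3.10+ to resolve the `|` unions; purely static use, as in this file, is unaffected.
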
@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str
 
 
-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
     return list_data
 
 
-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
     return output.getvalue()
 
 
-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")
@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
 
-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int=8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
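
The hunk above only shows the head of `quantize_embedding`; its body is not part of this diff. As a rough illustration of what a `(embedding, bits) -> tuple` signature together with the `quantized`/`min_val`/`max_val` fields further down usually implies, here is a hedged min-max quantization sketch. The names, the logic and the dequantization helper are illustrative assumptions, not this repository's code:

```python
from __future__ import annotations

import numpy as np


def quantize_embedding_sketch(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    """Illustrative min-max quantization of an embedding to unsigned 8-bit integers."""
    if isinstance(embedding, list):
        embedding = np.array(embedding, dtype=np.float32)
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / ((max_val - min_val) or 1.0)  # guard against a constant vector
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val


def dequantize_sketch(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8) -> np.ndarray:
    """Invert the sketch above; the stored min/max pair is what makes this possible."""
    scale = ((max_val - min_val) or 1.0) / (2**bits - 1)
    return quantized.astype(np.float32) * scale + min_val


vec = np.array([0.12, -0.40, 0.88], dtype=np.float32)
q, lo, hi = quantize_embedding_sketch(vec)
print(q, dequantize_sketch(q, lo, hi))  # round trip is accurate to roughly (hi - lo) / 255
```
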
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"
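
The `CacheData` fields modernized in this last hunk pair naturally with `quantize_embedding`'s tuple: an entry either leaves the embedding-related fields at their new `X | None` defaults or stores the quantized array together with the min/max range needed to restore it later. A minimal construction sketch follows, with the dataclass re-declared from the hunk above only so the snippet runs on its own; the field values are made up for illustration:

```python
from __future__ import annotations

from dataclasses import dataclass

import numpy as np


@dataclass
class CacheData:  # fields copied from the hunk above
    args_hash: str
    content: str
    prompt: str
    quantized: np.ndarray | None = None
    min_val: float | None = None
    max_val: float | None = None
    mode: str = "default"
    cache_type: str = "query"


# Text-only entry: the embedding fields simply stay at their `| None` defaults.
plain = CacheData(args_hash="hash-a", content="cached answer", prompt="original prompt")

# Entry with a quantized embedding: the uint8 array travels with the min/max range
# that a dequantization step needs to reconstruct the original floats.
with_embedding = CacheData(
    args_hash="hash-b",
    content="cached answer",
    prompt="original prompt",
    quantized=np.array([0, 128, 255], dtype=np.uint8),
    min_val=-0.40,
    max_val=0.88,
)
```
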