improved typing
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import io
|
||||
@@ -9,7 +11,7 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Union, List, Optional
|
||||
from typing import Any, Callable
|
||||
import xml.etree.ElementTree as ET
|
||||
import bs4
|
||||
|
||||
@@ -72,7 +74,7 @@ class ReasoningResponse:
|
||||
tag: str
|
||||
|
||||
|
||||
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
||||
def locate_json_string_body_from_string(content: str) -> str | None:
|
||||
"""Locate the JSON string body from a string"""
|
||||
try:
|
||||
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
||||
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
|
||||
return list_data
|
||||
|
||||
|
||||
def list_of_list_to_csv(data: List[List[str]]) -> str:
|
||||
def list_of_list_to_csv(data: list[list[str]]) -> str:
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(
|
||||
output,
|
||||
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def csv_string_to_list(csv_string: str) -> List[List[str]]:
|
||||
def csv_string_to_list(csv_string: str) -> list[list[str]]:
|
||||
# Clean the string by removing NUL characters
|
||||
cleaned_string = csv_string.replace("\0", "")
|
||||
|
||||
@@ -382,7 +384,7 @@ async def get_best_cached_response(
|
||||
llm_func=None,
|
||||
original_prompt=None,
|
||||
cache_type=None,
|
||||
) -> Union[str, None]:
|
||||
) -> str | None:
|
||||
logger.debug(
|
||||
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
|
||||
)
|
||||
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
|
||||
def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
|
||||
def quantize_embedding(embedding: np.ndarray | list[float], bits: int=8) -> tuple:
|
||||
"""Quantize embedding to specified bits"""
|
||||
# Convert list to numpy array if needed
|
||||
if isinstance(embedding, list):
|
||||
@@ -577,9 +579,9 @@ class CacheData:
|
||||
args_hash: str
|
||||
content: str
|
||||
prompt: str
|
||||
quantized: Optional[np.ndarray] = None
|
||||
min_val: Optional[float] = None
|
||||
max_val: Optional[float] = None
|
||||
quantized: np.ndarray | None = None
|
||||
min_val: float | None = None
|
||||
max_val: float | None = None
|
||||
mode: str = "default"
|
||||
cache_type: str = "query"
|
||||
|
||||
|
Reference in New Issue
Block a user