improved typing

Yannick Stephan
2025-02-15 22:37:12 +01:00
parent 8d0d8b8279
commit eaf1d553d2
9 changed files with 52 additions and 35 deletions
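
The pattern applied across all nine files is the same: `Optional[X]` and `Union[X, Y]` become PEP 604 unions (`X | None`, `X | Y`), `typing.List`/`typing.Dict` become the PEP 585 builtin generics `list`/`dict`, and each module gains `from __future__ import annotations` so the new syntax is stored as strings rather than evaluated, keeping the modules importable on interpreters older than Python 3.10. A minimal before/after sketch (names are illustrative, not from the diff):

```
from __future__ import annotations  # must be the first statement in the module

# Before: def lookup(ids: List[str]) -> Optional[Dict[str, Union[int, str]]]: ...
# After, with builtin generics and union syntax:
def lookup(ids: list[str]) -> dict[str, int | str] | None:
    return {i: n for n, i in enumerate(ids)} or None
```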

View File

@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )

 import numpy as np
@@ -115,7 +115,7 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc | None = None

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -157,21 +157,23 @@ class BaseGraphStorage(StorageNameSpace):
     """Get a node by its id."""

-    async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError

     """Get an edge by its source and target node ids."""

     async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict[str, str], None]:
+        self,
+        source_node_id: str,
+        target_node_id: str,
+    ) -> dict[str, str] | None:
         raise NotImplementedError

     """Get all edges connected to a node."""

     async def get_node_edges(
         self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    ) -> list[tuple[str, str]] | None:
         raise NotImplementedError

     """Upsert a node into the graph."""
@@ -236,9 +238,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal

View File

@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast

 from .base import (
     BaseGraphStorage,
@@ -314,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     # LLM Configuration
-    llm_model_func: Union[Callable[..., object], None] = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""

     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -354,7 +354,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
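
Because `llm_model_func` is now `Callable[..., object] | None` with a `None` default ("Must be set before use"), call sites have to narrow the optional away before invoking it. A hedged sketch of that guard pattern (the helper below is illustrative, not code from LightRAG):

```
from __future__ import annotations

from typing import Callable


def call_llm(llm_model_func: Callable[..., object] | None, prompt: str) -> object:
    # Narrowing: after this check, type checkers accept the call.
    if llm_model_func is None:
        raise ValueError("llm_model_func must be set before use")
    return llm_model_func(prompt)
```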

View File

@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field
@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """

-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0
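
The `_current_model` counter initialized here suggests `MultiModel` rotates requests across the configured models. A minimal round-robin sketch under that assumption (`_next_model` is a hypothetical helper name):

```
from __future__ import annotations

from typing import Any, Callable

from pydantic import BaseModel, Field


class Model(BaseModel):
    gen_func: Callable[..., Any] = Field(...)
    kwargs: dict[str, Any] = Field(...)


class MultiModel:
    def __init__(self, models: list[Model]):
        self._models = models
        self._current_model = 0

    def _next_model(self) -> Model:
        # Hypothetical rotation: advance the index and wrap around.
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]
```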

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable

View File

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, AsyncIterator, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time

 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
@@ -297,7 +299,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> Union[BaseGraphStorage, None]:
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
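
For the `chunking_by_token_size` hunk above, the `str | None` default means character-based splitting stays opt-in and pure token-window chunking is the fallback. An illustrative call under the shown signature (the import path and remaining parameters are assumptions, since this diff truncates the function):

```
from lightrag.operate import chunking_by_token_size  # assumed module path

chunks = chunking_by_token_size(
    "...long document text...",
    split_by_character=None,       # str | None: no character splitting
    split_by_character_only=False,
    overlap_token_size=128,        # tokens shared between neighboring chunks
    max_token_size=1024,           # upper bound per chunk
)
```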

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"

 PROMPTS = {}

View File

@@ -1,16 +1,19 @@
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any


 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]


 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here


 class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +21,9 @@ class KnowledgeGraphEdge(BaseModel):
     type: str
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here


 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
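
One subtlety in `KnowledgeGraph` survives the rewrite: the `= []` defaults are safe because pydantic copies mutable field defaults per instance, unlike plain class attributes (shared across instances) or dataclasses (which reject bare mutable defaults). A quick check with a simplified model:

```
from pydantic import BaseModel


class KnowledgeGraph(BaseModel):
    nodes: list[str] = []  # pydantic copies this default per instance
    edges: list[str] = []


a, b = KnowledgeGraph(), KnowledgeGraph()
a.nodes.append("n1")
print(b.nodes)  # [] - the default list is not shared between instances
```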

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET

 import bs4
@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str


-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
     return list_data


-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: list[list[str]]) -> str:
     return output.getvalue()


-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")
@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)


-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"
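
`CacheData` holding `quantized`, `min_val`, and `max_val` side by side points at the usual min-max scheme for `quantize_embedding`: store the vector as small integers plus the range needed to dequantize it. A sketch under that assumption (the diff truncates the real body, so this is not the actual implementation):

```
from __future__ import annotations

import numpy as np


def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    """Min-max quantize an embedding to `bits`-bit integers (assumed scheme, bits <= 8)."""
    if isinstance(embedding, list):
        embedding = np.array(embedding)
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / ((max_val - min_val) or 1.0)  # guard against zero range
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val
```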