improved typing

Yannick Stephan
2025-02-15 22:37:12 +01:00
parent 8d0d8b8279
commit eaf1d553d2
9 changed files with 52 additions and 35 deletions
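The change is mechanical but consistent: `typing.Optional[X]` and `typing.Union[X, Y]` become PEP 604 unions (`X | None`, `X | Y`), `typing.List`/`typing.Dict` become the PEP 585 builtin generics `list`/`dict`, and every touched module gains `from __future__ import annotations` so the new syntax stays valid in annotations on interpreters older than Python 3.10. A minimal sketch of the target style, using hypothetical functions rather than code from this diff:

from __future__ import annotations  # PEP 563: annotations are stored as strings

from typing import Any

# PEP 585: builtin generics replace typing.List / typing.Dict.
def get_by_ids(ids: list[str]) -> list[dict[str, Any]]:
    raise NotImplementedError

# PEP 604: the | union replaces Optional[...] / Union[...].
def get_by_id(id: str) -> dict[str, Any] | None:
    raise NotImplementedError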

View File

@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )

 import numpy as np
@@ -115,7 +115,7 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc | None = None

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -157,21 +157,23 @@ class BaseGraphStorage(StorageNameSpace):
     """Get a node by its id."""

-    async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError

     """Get an edge by its source and target node ids."""

     async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict[str, str], None]:
+        self,
+        source_node_id: str,
+        target_node_id: str,
+    ) -> dict[str, str] | None:
         raise NotImplementedError

     """Get all edges connected to a node."""

     async def get_node_edges(
         self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    ) -> list[tuple[str, str]] | None:
         raise NotImplementedError

     """Upsert a node into the graph."""
@@ -236,9 +238,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal

View File

@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast

 from .base import (
     BaseGraphStorage,
@@ -314,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     # LLM Configuration
-    llm_model_func: Union[Callable[..., object], None] = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""

     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -354,7 +354,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
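`llm_model_func: Callable[..., object] | None = None` makes the "must be set before use" contract visible to type checkers, which will now insist on a None check before any call. A hypothetical sketch of the narrowing a caller needs; none of this is taken from the repository:

from __future__ import annotations

from typing import Callable

llm_model_func: Callable[..., object] | None = None

def call_llm(prompt: str) -> object:
    # Narrow the optional before calling; a bare llm_model_func(prompt)
    # would be flagged as a possible call on None.
    if llm_model_func is None:
        raise RuntimeError("llm_model_func must be set before use")
    return llm_model_func(prompt)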

View File

@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field
@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """

-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0
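One caveat when combining `from __future__ import annotations` with pydantic: pydantic resolves the (now string) annotations at class-creation time to build validators, so names like `Callable` and `Any` must remain real runtime imports rather than being hidden behind `if TYPE_CHECKING`. A toy model in the same shape, as a sketch only:

from __future__ import annotations

from typing import Any, Callable

from pydantic import BaseModel, Field

class ToyModel(BaseModel):
    # pydantic evaluates these string annotations at class creation,
    # so Callable / Any must be importable at runtime.
    gen: Callable[..., str] = Field(..., description="generates the response")
    kwargs: dict[str, Any] = Field(..., description="arguments passed to gen")

m = ToyModel(gen=lambda prompt, **kw: prompt.upper(), kwargs={"api_key": "..."})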

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable

View File

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, AsyncIterator, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time
 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
@@ -297,7 +299,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> Union[BaseGraphStorage, None]:
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
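With the new signature, a call to the chunker reads the same on every supported Python version. A hypothetical invocation, assuming the module path `lightrag.operate` and that any parameters not shown in this hunk have defaults:

from lightrag.operate import chunking_by_token_size  # assumed import path

chunks = chunking_by_token_size(
    content="...long document text...",
    split_by_character="\n\n",      # or None to chunk purely by token count
    split_by_character_only=False,
    overlap_token_size=128,
    max_token_size=1024,
)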

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"

 PROMPTS = {}

View File

@@ -1,16 +1,19 @@
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any

 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]

 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here

 class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +21,9 @@ class KnowledgeGraphEdge(BaseModel):
     type: str
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here

 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
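The mutable defaults `= []` are safe here because pydantic copies field defaults for each instance, unlike plain class attributes (shared across instances) or stdlib dataclasses (which reject mutable defaults outright). A quick self-contained check of that behavior:

from __future__ import annotations

from pydantic import BaseModel

class Graph(BaseModel):
    nodes: list[str] = []  # copied per instance by pydantic

g1, g2 = Graph(), Graph()
g1.nodes.append("a")
assert g2.nodes == []  # no state shared between instances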

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET

 import bs4
@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str

-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
     return list_data

-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
     return output.getvalue()

-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")
@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)

-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
    cache_type: str = "query"
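For context on the `quantized`/`min_val`/`max_val` trio in `CacheData`: below is a minimal sketch of min/max linear quantization consistent with the `quantize_embedding` signature above. The `(quantized, min_val, max_val)` return shape is inferred from these fields; the project's actual scaling may differ.

from __future__ import annotations

import numpy as np

def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    """Quantize an embedding to 2**bits levels (sketch; assumes bits <= 8)."""
    if isinstance(embedding, list):  # convert list input, as in the diff
        embedding = np.array(embedding)
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / ((max_val - min_val) or 1.0)  # guard zero range
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val

q, lo, hi = quantize_embedding([0.1, -0.5, 0.9])
approx = q / ((2**8 - 1) / (hi - lo)) + lo  # rough inverse for reading back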