diff --git a/lightrag/base.py b/lightrag/base.py
index ae451dda..1d7a0a98 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )
 
 import numpy as np
@@ -115,7 +115,7 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc | None = None
 
-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -157,21 +157,23 @@ class BaseGraphStorage(StorageNameSpace):
 
     """Get a node by its id."""
 
-    async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError
 
     """Get an edge by its source and target node ids."""
 
     async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict[str, str], None]:
+        self,
+        source_node_id: str,
+        target_node_id: str,
+    ) -> dict[str, str] | None:
         raise NotImplementedError
 
     """Get all edges connected to a node."""
 
     async def get_node_edges(
         self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    ) -> list[tuple[str, str]] | None:
         raise NotImplementedError
 
     """Upsert a node into the graph."""
@@ -236,9 +238,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""
diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py
index 5de6b334..ae756f85 100644
--- a/lightrag/exceptions.py
+++ b/lightrag/exceptions.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal
 
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index af241f65..fed555a2 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast
 
 from .base import (
     BaseGraphStorage,
@@ -314,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""
 
     # LLM Configuration
-    llm_model_func: Union[Callable[..., object], None] = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -354,7 +354,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
diff --git a/lightrag/llm.py b/lightrag/llm.py
index b4baef68..e5f98cf8 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field
 
 
@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """
 
-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0
 
diff --git a/lightrag/namespace.py b/lightrag/namespace.py
index ba8e3072..77e04c9e 100644
--- a/lightrag/namespace.py
+++ b/lightrag/namespace.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable
 
 
diff --git a/lightrag/operate.py b/lightrag/operate.py
index f2bd6218..37e7523f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, AsyncIterator, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time
 
 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
@@ -297,7 +299,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> Union[BaseGraphStorage, None]:
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 160663d9..f4f5e38a 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"
 
 PROMPTS = {}
diff --git a/lightrag/types.py b/lightrag/types.py
index 9c8e0099..2510bed3 100644
--- a/lightrag/types.py
+++ b/lightrag/types.py
@@ -1,16 +1,19 @@
+
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any
 
 
 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]
 
 
 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +21,9 @@ class KnowledgeGraphEdge(BaseModel):
     type: str
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here
 
 
 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9b18d0c2..5b86ee78 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 
 import bs4
@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str
 
 
-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@
     return list_data
 
 
-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: list[list[str]]) -> str:
     return output.getvalue()
 
 
-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")
 
@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
 
-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"
 
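Reviewer's note, not part of the patch: the whole change leans on PEP 563 (`from __future__ import annotations`), which stores annotations as strings instead of evaluating them at definition time. That is what lets the PEP 604 `X | None` unions and PEP 585 builtin generics (`list[str]`, `dict[str, Any]`) appear in annotations even on interpreters older than 3.10/3.9. A minimal sketch of the mechanism (the function below is illustrative, not taken from the patch):

    from __future__ import annotations  # PEP 563: annotations stored as strings, not evaluated

    from typing import Any


    # Parses and runs on Python 3.8+, because the annotation is kept as the
    # string "dict[str, Any] | None" rather than evaluated as an expression.
    def get_by_id(id: str) -> dict[str, Any] | None:
        return None

One caveat the future import does not cover: anything that evaluates type expressions at runtime. A module-level alias such as `MaybeStr = str | None` is an ordinary expression and still raises `TypeError` before Python 3.10, and tools that resolve annotations at runtime (e.g. `typing.get_type_hints`, or pydantic validating the `BaseModel` fields touched above) still need an interpreter that supports the new syntax.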