improved typing
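The diff below modernizes annotations across the codebase: `Union[X, None]` and `Optional[X]` become `X | None` (PEP 604), and `typing.List`/`typing.Dict` become the built-in generics `list`/`dict` (PEP 585). Each touched module gains `from __future__ import annotations` (PEP 563), which stores annotations as lazily evaluated strings so the new syntax still imports on interpreters older than 3.10. A minimal sketch of the pattern, with hypothetical names not taken from this codebase:

```python
from __future__ import annotations  # defer annotation evaluation (PEP 563)


# `dict[str, str] | None` parses here even on Python 3.8/3.9, because the
# annotation is stored as a string rather than evaluated at import time.
def find_user(users: dict[str, dict[str, str]], user_id: str) -> dict[str, str] | None:
    return users.get(user_id)
```

Note the future import only defers annotations: a union such as `str | None` used at runtime, for example in an `isinstance` check, still requires Python 3.10.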
@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )

 import numpy as np
@@ -115,7 +115,7 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc | None = None

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
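For context, these storage classes are abstract interfaces whose coroutines raise `NotImplementedError`; backends override them with the signatures shown. A minimal in-memory sketch of a conforming subclass, assuming the bases are dataclasses living in `lightrag.base` (the subclass and its storage dict are hypothetical, not part of this commit):

```python
from typing import Any

from lightrag.base import BaseKVStorage


class InMemoryKVStorage(BaseKVStorage):
    # Hypothetical illustration; assumes BaseKVStorage is a dataclass,
    # so per-instance state is set up in __post_init__.
    def __post_init__(self):
        self._data: dict[str, dict[str, Any]] = {}

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        return self._data.get(id)

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        return [self._data[i] for i in ids if i in self._data]
```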
@@ -157,21 +157,23 @@ class BaseGraphStorage(StorageNameSpace):

     """Get a node by its id."""

-    async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError

     """Get an edge by its source and target node ids."""

     async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict[str, str], None]:
+        self,
+        source_node_id: str,
+        target_node_id: str
+    ) -> dict[str, str] | None:
         raise NotImplementedError

     """Get all edges connected to a node."""

     async def get_node_edges(
         self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    ) -> list[tuple[str, str]] | None:
         raise NotImplementedError

     """Upsert a node into the graph."""
@@ -236,9 +238,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count: Optional[int] = None
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error: Optional[str] = None
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""
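Worth noting in `DocProcessingStatus`: the `None`-defaulted fields take their default directly, while the mutable `metadata` default must go through `field(default_factory=dict)`, since dataclasses reject bare mutable defaults. A small illustration of the rule, using a hypothetical class:

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Status:
    chunks_count: int | None = None  # immutable default: allowed as-is
    metadata: dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
    # metadata: dict[str, Any] = {}  # would raise ValueError at class definition
```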
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal

@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast

 from .base import (
     BaseGraphStorage,
@@ -314,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     # LLM Configuration
-    llm_model_func: Union[Callable[..., object], None] = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""

     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -354,7 +354,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
@@ -1,4 +1,6 @@
-from typing import List, Dict, Callable, Any
+from __future__ import annotations
+
+from typing import Callable, Any
 from pydantic import BaseModel, Field

@@ -23,7 +25,7 @@ class Model(BaseModel):
         ...,
         description="A function that generates the response from the llm. The response must be a string",
     )
-    kwargs: Dict[str, Any] = Field(
+    kwargs: dict[str, Any] = Field(
         ...,
         description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
     )
@@ -57,7 +59,7 @@ class MultiModel:
     ```
     """

-    def __init__(self, models: List[Model]):
+    def __init__(self, models: list[Model]):
         self._models = models
         self._current_model = 0

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import re
 from tqdm.asyncio import tqdm as tqdm_async
-from typing import Any, AsyncIterator, Union
+from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 from .utils import (
     logger,
@@ -36,7 +38,7 @@ import time

 def chunking_by_token_size(
     content: str,
-    split_by_character: Union[str, None] = None,
+    split_by_character: str | None = None,
     split_by_character_only: bool = False,
     overlap_token_size: int = 128,
     max_token_size: int = 1024,
@@ -297,7 +299,7 @@ async def extract_entities(
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
-) -> Union[BaseGraphStorage, None]:
+) -> BaseGraphStorage | None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 GRAPH_FIELD_SEP = "<SEP>"

 PROMPTS = {}
@@ -1,16 +1,19 @@
+from __future__ import annotations
+
 from pydantic import BaseModel
-from typing import List, Dict, Any
+from typing import Any


 class GPTKeywordExtractionFormat(BaseModel):
-    high_level_keywords: List[str]
-    low_level_keywords: List[str]
+    high_level_keywords: list[str]
+    low_level_keywords: list[str]


 class KnowledgeGraphNode(BaseModel):
     id: str
-    labels: List[str]
-    properties: Dict[str, Any]  # anything else goes here
+    labels: list[str]
+    properties: dict[str, Any]  # anything else goes here


 class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +21,9 @@ class KnowledgeGraphEdge(BaseModel):
     type: str
     source: str  # id of source node
     target: str  # id of target node
-    properties: Dict[str, Any]  # anything else goes here
+    properties: dict[str, Any]  # anything else goes here


 class KnowledgeGraph(BaseModel):
-    nodes: List[KnowledgeGraphNode] = []
-    edges: List[KnowledgeGraphEdge] = []
+    nodes: list[KnowledgeGraphNode] = []
+    edges: list[KnowledgeGraphEdge] = []
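Unlike dataclasses, Pydantic models tolerate the bare mutable defaults kept here (`nodes: list[...] = []`): Pydantic copies the default for each instance, so two models never share the same list. A quick check of that behavior, assuming the models above:

```python
g1 = KnowledgeGraph()
g2 = KnowledgeGraph()
g1.nodes.append(KnowledgeGraphNode(id="n1", labels=["Person"], properties={}))
assert g2.nodes == []  # g2 received its own copy of the default list
```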
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import html
 import io
@@ -9,7 +11,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable, Union, List, Optional
+from typing import Any, Callable
 import xml.etree.ElementTree as ET
 import bs4

@@ -72,7 +74,7 @@ class ReasoningResponse:
     tag: str


-def locate_json_string_body_from_string(content: str) -> Union[str, None]:
+def locate_json_string_body_from_string(content: str) -> str | None:
     """Locate the JSON string body from a string"""
     try:
         maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -238,7 +240,7 @@ def truncate_list_by_token_size(
     return list_data


-def list_of_list_to_csv(data: List[List[str]]) -> str:
+def list_of_list_to_csv(data: list[list[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(
         output,
@@ -251,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
     return output.getvalue()


-def csv_string_to_list(csv_string: str) -> List[List[str]]:
+def csv_string_to_list(csv_string: str) -> list[list[str]]:
     # Clean the string by removing NUL characters
     cleaned_string = csv_string.replace("\0", "")

@@ -382,7 +384,7 @@ async def get_best_cached_response(
     llm_func=None,
     original_prompt=None,
     cache_type=None,
-) -> Union[str, None]:
+) -> str | None:
     logger.debug(
         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
     )
@@ -486,7 +488,7 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)


-def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
+def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
     """Quantize embedding to specified bits"""
     # Convert list to numpy array if needed
     if isinstance(embedding, list):
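The widened input type (`np.ndarray | list[float]`) pairs with the cache fields in `CacheData` below (`quantized`, `min_val`, `max_val`). Since the function body is mostly elided in this hunk, here is a generic min-max quantizer of the same shape, written fresh for illustration rather than copied from the repo:

```python
import numpy as np


def quantize_embedding_sketch(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    # Hypothetical re-implementation: scale values into [0, 2**bits - 1]
    # and keep (min_val, max_val) so the embedding can be dequantized later.
    if isinstance(embedding, list):
        embedding = np.array(embedding)
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / ((max_val - min_val) or 1.0)
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val
```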
@@ -577,9 +579,9 @@ class CacheData:
     args_hash: str
     content: str
     prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
+    quantized: np.ndarray | None = None
+    min_val: float | None = None
+    max_val: float | None = None
     mode: str = "default"
     cache_type: str = "query"