diff --git a/README.md b/README.md index 62f21a65..487c65f5 100644 --- a/README.md +++ b/README.md @@ -428,9 +428,9 @@ And using a routine to process news documents. ```python rag = LightRAG(..) -await rag.apipeline_enqueue_documents(string_or_strings) +await rag.apipeline_enqueue_documents(input) # Your routine in loop -await rag.apipeline_process_enqueue_documents(string_or_strings) +await rag.apipeline_process_enqueue_documents(input) ``` ### Separate Keyword Extraction diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index f5269fae..9c90424e 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -113,7 +113,24 @@ async def main(): ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.set_storage_client(db_client=oracle_db) + + for storage in [ + rag.vector_db_storage_cls, + rag.graph_storage_cls, + rag.doc_status, + rag.full_docs, + rag.text_chunks, + rag.llm_response_cache, + rag.key_string_value_json_storage_cls, + rag.chunks_vdb, + rag.relationships_vdb, + rag.entities_vdb, + rag.graph_storage_cls, + rag.chunk_entity_relation_graph, + rag.llm_response_cache, + ]: + # set client + storage.db = oracle_db # Extract and Insert into LightRAG storage with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: diff --git a/examples/test_chromadb.py b/examples/test_chromadb.py index 5293f05d..99090a6d 100644 --- a/examples/test_chromadb.py +++ b/examples/test_chromadb.py @@ -17,7 +17,9 @@ if not os.path.exists(WORKING_DIR): # ChromaDB Configuration CHROMADB_USE_LOCAL_PERSISTENT = False # Local PersistentClient Configuration -CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")) +CHROMADB_LOCAL_PATH = os.environ.get( + "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data") +) # Remote HttpClient Configuration CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost") CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000)) diff --git a/external_bindings/OpenWebuiTool/openwebui_tool.py b/external_bindings/OpenWebuiTool/openwebui_tool.py deleted file mode 100644 index 8df3109c..00000000 --- a/external_bindings/OpenWebuiTool/openwebui_tool.py +++ /dev/null @@ -1,358 +0,0 @@ -""" -OpenWebui Lightrag Integration Tool -================================== - -This tool enables the integration and use of Lightrag within the OpenWebui environment, -providing a seamless interface for RAG (Retrieval-Augmented Generation) operations. - -Author: ParisNeo (parisneoai@gmail.com) -Social: - - Twitter: @ParisNeo_AI - - Reddit: r/lollms - - Instagram: https://www.instagram.com/parisneo_ai/ - -License: Apache 2.0 -Copyright (c) 2024-2025 ParisNeo - -This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems). 
-For more information, visit: https://github.com/ParisNeo/lollms - -Requirements: - - Python 3.8+ - - OpenWebui - - Lightrag -""" - -# Tool version -__version__ = "1.0.0" -__author__ = "ParisNeo" -__author_email__ = "parisneoai@gmail.com" -__description__ = "Lightrag integration for OpenWebui" - - -import requests -import json -from pydantic import BaseModel, Field -from typing import Callable, Any, Literal, Union, List, Tuple - - -class StatusEventEmitter: - def __init__(self, event_emitter: Callable[[dict], Any] = None): - self.event_emitter = event_emitter - - async def emit(self, description="Unknown State", status="in_progress", done=False): - if self.event_emitter: - await self.event_emitter( - { - "type": "status", - "data": { - "status": status, - "description": description, - "done": done, - }, - } - ) - - -class MessageEventEmitter: - def __init__(self, event_emitter: Callable[[dict], Any] = None): - self.event_emitter = event_emitter - - async def emit(self, content="Some message"): - if self.event_emitter: - await self.event_emitter( - { - "type": "message", - "data": { - "content": content, - }, - } - ) - - -class Tools: - class Valves(BaseModel): - LIGHTRAG_SERVER_URL: str = Field( - default="http://localhost:9621/query", - description="The base URL for the LightRag server", - ) - MODE: Literal["naive", "local", "global", "hybrid"] = Field( - default="hybrid", - description="The mode to use for the LightRag query. Options: naive, local, global, hybrid", - ) - ONLY_NEED_CONTEXT: bool = Field( - default=False, - description="If True, only the context is needed from the LightRag response", - ) - DEBUG_MODE: bool = Field( - default=False, - description="If True, debugging information will be emitted", - ) - KEY: str = Field( - default="", - description="Optional Bearer Key for authentication", - ) - MAX_ENTITIES: int = Field( - default=5, - description="Maximum number of entities to keep", - ) - MAX_RELATIONSHIPS: int = Field( - default=5, - description="Maximum number of relationships to keep", - ) - MAX_SOURCES: int = Field( - default=3, - description="Maximum number of sources to keep", - ) - - def __init__(self): - self.valves = self.Valves() - self.headers = { - "Content-Type": "application/json", - "User-Agent": "LightRag-Tool/1.0", - } - - async def query_lightrag( - self, - query: str, - __event_emitter__: Callable[[dict], Any] = None, - ) -> str: - """ - Query the LightRag server and retrieve information. - This function must be called before answering the user question - :params query: The query string to send to the LightRag server. - :return: The response from the LightRag server in Markdown format or raw response. 
- """ - self.status_emitter = StatusEventEmitter(__event_emitter__) - self.message_emitter = MessageEventEmitter(__event_emitter__) - - lightrag_url = self.valves.LIGHTRAG_SERVER_URL - payload = { - "query": query, - "mode": str(self.valves.MODE), - "stream": False, - "only_need_context": self.valves.ONLY_NEED_CONTEXT, - } - await self.status_emitter.emit("Initializing Lightrag query..") - - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "### Debug Mode Active\n\nDebugging information will be displayed.\n" - ) - await self.message_emitter.emit( - "#### Payload Sent to LightRag Server\n```json\n" - + json.dumps(payload, indent=4) - + "\n```\n" - ) - - # Add Bearer Key to headers if provided - if self.valves.KEY: - self.headers["Authorization"] = f"Bearer {self.valves.KEY}" - - try: - await self.status_emitter.emit("Sending request to LightRag server") - - response = requests.post( - lightrag_url, json=payload, headers=self.headers, timeout=120 - ) - response.raise_for_status() - data = response.json() - await self.status_emitter.emit( - status="complete", - description="LightRag query Succeeded", - done=True, - ) - - # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response - if self.valves.ONLY_NEED_CONTEXT: - try: - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "#### LightRag Server Response\n```json\n" - + data["response"] - + "\n```\n" - ) - except Exception as ex: - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "#### Exception\n" + str(ex) + "\n" - ) - return f"Exception: {ex}" - return data["response"] - else: - if self.valves.DEBUG_MODE: - await self.message_emitter.emit( - "#### LightRag Server Response\n```json\n" - + data["response"] - + "\n```\n" - ) - await self.status_emitter.emit("Lightrag query success") - return data["response"] - - except requests.exceptions.RequestException as e: - await self.status_emitter.emit( - status="error", - description=f"Error during LightRag query: {str(e)}", - done=True, - ) - return json.dumps({"error": str(e)}) - - def extract_code_blocks( - self, text: str, return_remaining_text: bool = False - ) -> Union[List[dict], Tuple[List[dict], str]]: - """ - This function extracts code blocks from a given text and optionally returns the text without code blocks. - - Parameters: - text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```). - return_remaining_text (bool): If True, also returns the text with code blocks removed. 
- - Returns: - Union[List[dict], Tuple[List[dict], str]]: - - If return_remaining_text is False: Returns only the list of code block dictionaries - - If return_remaining_text is True: Returns a tuple containing: - * List of code block dictionaries - * String containing the text with all code blocks removed - - Each code block dictionary contains: - - 'index' (int): The index of the code block in the text - - 'file_name' (str): The name of the file extracted from the preceding line, if available - - 'content' (str): The content of the code block - - 'type' (str): The type of the code block - - 'is_complete' (bool): True if the block has a closing tag, False otherwise - """ - remaining = text - bloc_index = 0 - first_index = 0 - indices = [] - text_without_blocks = text - - # Find all code block delimiters - while len(remaining) > 0: - try: - index = remaining.index("```") - indices.append(index + first_index) - remaining = remaining[index + 3 :] - first_index += index + 3 - bloc_index += 1 - except Exception: - if bloc_index % 2 == 1: - index = len(remaining) - indices.append(index) - remaining = "" - - code_blocks = [] - is_start = True - - # Process code blocks and build text without blocks if requested - if return_remaining_text: - text_parts = [] - last_end = 0 - - for index, code_delimiter_position in enumerate(indices): - if is_start: - block_infos = { - "index": len(code_blocks), - "file_name": "", - "section": "", - "content": "", - "type": "", - "is_complete": False, - } - - # Store text before code block if returning remaining text - if return_remaining_text: - text_parts.append(text[last_end:code_delimiter_position].strip()) - - # Check the preceding line for file name - preceding_text = text[:code_delimiter_position].strip().splitlines() - if preceding_text: - last_line = preceding_text[-1].strip() - if last_line.startswith("") and last_line.endswith( - "" - ): - file_name = last_line[ - len("") : -len("") - ].strip() - block_infos["file_name"] = file_name - elif last_line.startswith("## filename:"): - file_name = last_line[len("## filename:") :].strip() - block_infos["file_name"] = file_name - if last_line.startswith("
") and last_line.endswith( - "
" - ): - section = last_line[ - len("
") : -len("
") - ].strip() - block_infos["section"] = section - - sub_text = text[code_delimiter_position + 3 :] - if len(sub_text) > 0: - try: - find_space = sub_text.index(" ") - except Exception: - find_space = int(1e10) - try: - find_return = sub_text.index("\n") - except Exception: - find_return = int(1e10) - next_index = min(find_return, find_space) - if "{" in sub_text[:next_index]: - next_index = 0 - start_pos = next_index - - if code_delimiter_position + 3 < len(text) and text[ - code_delimiter_position + 3 - ] in ["\n", " ", "\t"]: - block_infos["type"] = "language-specific" - else: - block_infos["type"] = sub_text[:next_index] - - if index + 1 < len(indices): - next_pos = indices[index + 1] - code_delimiter_position - if ( - next_pos - 3 < len(sub_text) - and sub_text[next_pos - 3] == "`" - ): - block_infos["content"] = sub_text[ - start_pos : next_pos - 3 - ].strip() - block_infos["is_complete"] = True - else: - block_infos["content"] = sub_text[ - start_pos:next_pos - ].strip() - block_infos["is_complete"] = False - - if return_remaining_text: - last_end = indices[index + 1] + 3 - else: - block_infos["content"] = sub_text[start_pos:].strip() - block_infos["is_complete"] = False - - if return_remaining_text: - last_end = len(text) - - code_blocks.append(block_infos) - is_start = False - else: - is_start = True - - if return_remaining_text: - # Add any remaining text after the last code block - if last_end < len(text): - text_parts.append(text[last_end:].strip()) - # Join all non-code parts with newlines - text_without_blocks = "\n".join(filter(None, text_parts)) - return code_blocks, text_without_blocks - - return code_blocks - - def clean(self, csv_content: str): - lines = csv_content.splitlines() - if lines: - # Remove spaces around headers and ensure no spaces between commas - header = ",".join([col.strip() for col in lines[0].split(",")]) - lines[0] = header # Replace the first line with the cleaned header - csv_content = "\n".join(lines) - return csv_content diff --git a/lightrag/base.py b/lightrag/base.py index e75167c4..3d4fc022 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,13 +1,13 @@ +from __future__ import annotations + import os from dataclasses import dataclass, field from enum import Enum from typing import ( Any, Literal, - Optional, TypedDict, TypeVar, - Union, ) import numpy as np @@ -69,7 +69,7 @@ class QueryParam: ll_keywords: list[str] = field(default_factory=list) """List of low-level keywords to refine retrieval focus.""" - conversation_history: list[dict[str, Any]] = field(default_factory=list) + conversation_history: list[dict[str, str]] = field(default_factory=list) """Stores past conversation history to maintain context. Format: [{"role": "user/assistant", "content": "message"}]. 
""" @@ -83,19 +83,15 @@ class StorageNameSpace: namespace: str global_config: dict[str, Any] - async def index_done_callback(self): + async def index_done_callback(self) -> None: """Commit the storage operations after indexing""" pass - async def query_done_callback(self): - """Commit the storage operations after querying""" - pass - @dataclass class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc - meta_fields: set = field(default_factory=set) + meta_fields: set[str] = field(default_factory=set) async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError @@ -106,12 +102,20 @@ class BaseVectorStorage(StorageNameSpace): """ raise NotImplementedError + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError + @dataclass class BaseKVStorage(StorageNameSpace): - embedding_func: EmbeddingFunc + embedding_func: EmbeddingFunc | None = None - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: raise NotImplementedError async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -130,50 +134,75 @@ class BaseKVStorage(StorageNameSpace): @dataclass class BaseGraphStorage(StorageNameSpace): - embedding_func: EmbeddingFunc = None + embedding_func: EmbeddingFunc | None = None + """Check if a node exists in the graph.""" async def has_node(self, node_id: str) -> bool: raise NotImplementedError + """Check if an edge exists in the graph.""" + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError + """Get the degree of a node.""" + async def node_degree(self, node_id: str) -> int: raise NotImplementedError + """Get the degree of an edge.""" + async def edge_degree(self, src_id: str, tgt_id: str) -> int: raise NotImplementedError - async def get_node(self, node_id: str) -> Union[dict, None]: + """Get a node by its id.""" + + async def get_node(self, node_id: str) -> dict[str, str] | None: raise NotImplementedError + """Get an edge by its source and target node ids.""" + async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> dict[str, str] | None: raise NotImplementedError - async def get_node_edges( - self, source_node_id: str - ) -> Union[list[tuple[str, str]], None]: + """Get all edges connected to a node.""" + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: raise NotImplementedError - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + """Upsert a node into the graph.""" + + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: raise NotImplementedError + """Upsert an edge into the graph.""" + async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: raise NotImplementedError - async def delete_node(self, node_id: str): + """Delete a node from the graph.""" + + async def delete_node(self, node_id: str) -> None: raise NotImplementedError - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + """Embed nodes using an algorithm.""" + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError("Node embedding is not used in 
lightrag.") + """Get all labels in the graph.""" + async def get_all_labels(self) -> list[str]: raise NotImplementedError + """Get a knowledge graph of a node.""" + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: @@ -205,9 +234,9 @@ class DocProcessingStatus: """ISO format timestamp when document was created""" updated_at: str """ISO format timestamp when document was last updated""" - chunks_count: Optional[int] = None + chunks_count: int | None = None """Number of chunks after splitting, used for processing""" - error: Optional[str] = None + error: str | None = None """Error message if failed""" metadata: dict[str, Any] = field(default_factory=dict) """Additional metadata""" diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py index 5de6b334..ae756f85 100644 --- a/lightrag/exceptions.py +++ b/lightrag/exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import httpx from typing import Literal diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 7b7642d6..cb3b59f1 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -67,7 +67,9 @@ class ChromaVectorDBStorage(BaseVectorStorage): if "token_authn" in auth_provider: headers = { - config.get("auth_header_name", "X-Chroma-Token"): auth_credentials + config.get( + "auth_header_name", "X-Chroma-Token" + ): auth_credentials } elif "basic_authn" in auth_provider: auth_credentials = config.get("auth_credentials", "admin:admin") @@ -154,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func([query]) results = self._collection.query( - query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding, + query_embeddings=embedding.tolist() + if not isinstance(embedding, list) + else embedding, n_results=top_k * 2, # Request more results to allow for filtering include=["metadatas", "distances", "documents"], ) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 0dca9e4c..9a5f7e4e 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") await self.delete([entity_id]) - async def delete_entity_relation(self, entity_name: str): + async def delete_entity_relation(self, entity_name: str) -> None: """ Delete relations for a given entity by scanning metadata. 
""" diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 2db8f72a..5d786646 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") - async def delete_entity_relation(self, entity_name: str): + async def delete_entity_relation(self, entity_name: str) -> None: try: relations = [ dp diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ed0dec29..23c3df80 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import asyncio import os import configparser from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Any, Callable, Optional, Type, Union, cast +from typing import Any, AsyncIterator, Callable, Iterator, cast from .base import ( BaseGraphStorage, @@ -92,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = { } # Storage implementation environment variable without default value -STORAGE_ENV_REQUIREMENTS = { +STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { # KV Storage Implementations "JsonKVStorage": [], "MongoKVStorage": [], @@ -179,7 +181,7 @@ STORAGES = { } -def lazy_external_import(module_name: str, class_name: str): +def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: """Lazily import a class from an external module based on the package of the caller.""" # Get the caller's module and package import inspect @@ -188,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str): module = inspect.getmodule(caller_frame) package = module.__package__ if module else None - def import_class(*args, **kwargs): + def import_class(*args: Any, **kwargs: Any): import importlib module = importlib.import_module(module_name, package=package) @@ -305,7 +307,7 @@ class LightRAG: - random_seed: Seed value for reproducibility. """ - embedding_func: EmbeddingFunc = None + embedding_func: EmbeddingFunc | None = None """Function for computing text embeddings. Must be set before use.""" embedding_batch_num: int = 32 @@ -315,7 +317,7 @@ class LightRAG: """Maximum number of concurrent embedding function calls.""" # LLM Configuration - llm_model_func: callable = None + llm_model_func: Callable[..., object] | None = None """Function for interacting with the large language model (LLM). 
Must be set before use.""" llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" @@ -345,10 +347,8 @@ class LightRAG: # Extensions addon_params: dict[str, Any] = field(default_factory=dict) - """Dictionary for additional parameters and extensions.""" - # extension - addon_params: dict[str, Any] = field(default_factory=dict) + """Dictionary for additional parameters and extensions.""" convert_response_to_json_func: Callable[[str], dict[str, Any]] = ( convert_response_to_json ) @@ -357,7 +357,7 @@ class LightRAG: chunking_func: Callable[ [ str, - Optional[str], + str | None, bool, int, int, @@ -446,77 +446,74 @@ class LightRAG: logger.debug(f"LightRAG init with param:\n {_print_config}\n") # Init LLM - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( + self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore self.embedding_func ) # Initialize all storages - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self.key_string_value_json_storage_cls: type[BaseKVStorage] = ( self._get_storage_class(self.kv_storage) - ) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( + ) # type: ignore + self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class( self.vector_storage - ) - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( + ) # type: ignore + self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class( self.graph_storage - ) - - self.key_string_value_json_storage_cls = partial( + ) # type: ignore + self.key_string_value_json_storage_cls = partial( # type: ignore self.key_string_value_json_storage_cls, global_config=global_config ) - - self.vector_db_storage_cls = partial( + self.vector_db_storage_cls = partial( # type: ignore self.vector_db_storage_cls, global_config=global_config ) - - self.graph_storage_cls = partial( + self.graph_storage_cls = partial( # type: ignore self.graph_storage_cls, global_config=global_config ) # Initialize document status storage self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) - self.llm_response_cache = self.key_string_value_json_storage_cls( + self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), embedding_func=self.embedding_func, ) - self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( + self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS ), embedding_func=self.embedding_func, ) - self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( + self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS ), embedding_func=self.embedding_func, ) - self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( + self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION ), embedding_func=self.embedding_func, ) - self.entities_vdb = self.vector_db_storage_cls( + self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES ), embedding_func=self.embedding_func, 
meta_fields={"entity_name"}, ) - self.relationships_vdb = self.vector_db_storage_cls( + self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS ), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) - self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( + self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS ), @@ -530,13 +527,12 @@ class LightRAG: embedding_func=None, ) - # What's for, Is this nessisary ? if self.llm_response_cache and hasattr( self.llm_response_cache, "global_config" ): hashing_kv = self.llm_response_cache else: - hashing_kv = self.key_string_value_json_storage_cls( + hashing_kv = self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), @@ -545,7 +541,7 @@ class LightRAG: self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( - self.llm_model_func, + self.llm_model_func, # type: ignore hashing_kv=hashing_kv, **self.llm_model_kwargs, ) @@ -562,68 +558,45 @@ class LightRAG: node_label=nodel_label, max_depth=max_depth ) - def _get_storage_class(self, storage_name: str) -> dict: + def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: import_path = STORAGES[storage_name] storage_class = lazy_external_import(import_path, storage_name) return storage_class - def set_storage_client(self, db_client): - # Deprecated, seting correct value to *_storage of LightRAG insteaded - # Inject db to storage implementation (only tested on Oracle Database) - for storage in [ - self.vector_db_storage_cls, - self.graph_storage_cls, - self.doc_status, - self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.key_string_value_json_storage_cls, - self.chunks_vdb, - self.relationships_vdb, - self.entities_vdb, - self.graph_storage_cls, - self.chunk_entity_relation_graph, - self.llm_response_cache, - ]: - # set client - storage.db = db_client - def insert( self, - string_or_strings: Union[str, list[str]], + input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ): """Sync Insert documents with checkpoint support Args: - string_or_strings: Single document string or list of document strings + input: Single document string or list of document strings split_by_character: if split_by_character is not None, split the string by character, if chunk longer than - chunk_size, split the sub chunk by token size. split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. 
""" loop = always_get_an_event_loop() return loop.run_until_complete( - self.ainsert(string_or_strings, split_by_character, split_by_character_only) + self.ainsert(input, split_by_character, split_by_character_only) ) async def ainsert( self, - string_or_strings: Union[str, list[str]], + input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ): """Async Insert documents with checkpoint support Args: - string_or_strings: Single document string or list of document strings + input: Single document string or list of document strings split_by_character: if split_by_character is not None, split the string by character, if chunk longer than - chunk_size, split the sub chunk by token size. split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. """ - await self.apipeline_enqueue_documents(string_or_strings) + await self.apipeline_enqueue_documents(input) await self.apipeline_process_enqueue_documents( split_by_character, split_by_character_only ) @@ -680,7 +653,7 @@ class LightRAG: if update_storage: await self._insert_done() - async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]): + async def apipeline_enqueue_documents(self, input: str | list[str]): """ Pipeline for Processing Documents @@ -689,11 +662,11 @@ class LightRAG: 3. Filter out already processed documents 4. Enqueue document in status """ - if isinstance(string_or_strings, str): - string_or_strings = [string_or_strings] + if isinstance(input, str): + input = [input] # 1. Remove duplicate contents from the list - unique_contents = list(set(doc.strip() for doc in string_or_strings)) + unique_contents = list(set(doc.strip() for doc in input)) # 2. 
Generate document IDs and initial status new_docs: dict[str, Any] = { @@ -860,32 +833,32 @@ class LightRAG: raise e async def _insert_done(self): - tasks = [] - for storage_inst in [ - self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.entities_vdb, - self.relationships_vdb, - self.chunks_vdb, - self.chunk_entity_relation_graph, - ]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) + tasks = [ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.entities_vdb, + self.relationships_vdb, + self.chunks_vdb, + self.chunk_entity_relation_graph, + ] + if storage_inst is not None + ] await asyncio.gather(*tasks) - def insert_custom_kg(self, custom_kg: dict): + def insert_custom_kg(self, custom_kg: dict[str, Any]): loop = always_get_an_event_loop() return loop.run_until_complete(self.ainsert_custom_kg(custom_kg)) - async def ainsert_custom_kg(self, custom_kg: dict): + async def ainsert_custom_kg(self, custom_kg: dict[str, Any]): update_storage = False try: # Insert chunks into vector storage - all_chunks_data = {} - chunk_to_source_map = {} - for chunk_data in custom_kg.get("chunks", []): + all_chunks_data: dict[str, dict[str, str]] = {} + chunk_to_source_map: dict[str, str] = {} + for chunk_data in custom_kg.get("chunks", {}): chunk_content = chunk_data["content"] source_id = chunk_data["source_id"] chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-") @@ -895,13 +868,13 @@ class LightRAG: chunk_to_source_map[source_id] = chunk_id update_storage = True - if self.chunks_vdb is not None and all_chunks_data: + if all_chunks_data: await self.chunks_vdb.upsert(all_chunks_data) - if self.text_chunks is not None and all_chunks_data: + if all_chunks_data: await self.text_chunks.upsert(all_chunks_data) # Insert entities into knowledge graph - all_entities_data = [] + all_entities_data: list[dict[str, str]] = [] for entity_data in custom_kg.get("entities", []): entity_name = f'"{entity_data["entity_name"].upper()}"' entity_type = entity_data.get("entity_type", "UNKNOWN") @@ -917,7 +890,7 @@ class LightRAG: ) # Prepare node data - node_data = { + node_data: dict[str, str] = { "entity_type": entity_type, "description": description, "source_id": source_id, @@ -931,7 +904,7 @@ class LightRAG: update_storage = True # Insert relationships into knowledge graph - all_relationships_data = [] + all_relationships_data: list[dict[str, str]] = [] for relationship_data in custom_kg.get("relationships", []): src_id = f'"{relationship_data["src_id"].upper()}"' tgt_id = f'"{relationship_data["tgt_id"].upper()}"' @@ -973,7 +946,7 @@ class LightRAG: "source_id": source_id, }, ) - edge_data = { + edge_data: dict[str, str] = { "src_id": src_id, "tgt_id": tgt_id, "description": description, @@ -983,41 +956,68 @@ class LightRAG: update_storage = True # Insert entities into vector storage if needed - if self.entities_vdb is not None: - data_for_vdb = { - compute_mdhash_id(dp["entity_name"], prefix="ent-"): { - "content": dp["entity_name"] + dp["description"], - "entity_name": dp["entity_name"], - } - for dp in all_entities_data + data_for_vdb = { + compute_mdhash_id(dp["entity_name"], prefix="ent-"): { + "content": dp["entity_name"] + dp["description"], + "entity_name": dp["entity_name"], } - await self.entities_vdb.upsert(data_for_vdb) + for dp in all_entities_data + } + await self.entities_vdb.upsert(data_for_vdb) # 
Insert relationships into vector storage if needed - if self.relationships_vdb is not None: - data_for_vdb = { - compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { - "src_id": dp["src_id"], - "tgt_id": dp["tgt_id"], - "content": dp["keywords"] - + dp["src_id"] - + dp["tgt_id"] - + dp["description"], - } - for dp in all_relationships_data + data_for_vdb = { + compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { + "src_id": dp["src_id"], + "tgt_id": dp["tgt_id"], + "content": dp["keywords"] + + dp["src_id"] + + dp["tgt_id"] + + dp["description"], } - await self.relationships_vdb.upsert(data_for_vdb) + for dp in all_relationships_data + } + await self.relationships_vdb.upsert(data_for_vdb) + finally: if update_storage: await self._insert_done() - def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()): + def query( + self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None + ) -> str | Iterator[str]: + """ + Perform a sync query. + + Args: + query (str): The query to be executed. + param (QueryParam): Configuration parameters for query execution. + prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. + + Returns: + str: The result of the query execution. + """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, prompt, param)) + + return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore async def aquery( - self, query: str, prompt: str = "", param: QueryParam = QueryParam() - ): + self, + query: str, + param: QueryParam = QueryParam(), + prompt: str | None = None, + ) -> str | AsyncIterator[str]: + """ + Perform a async query. + + Args: + query (str): The query to be executed. + param (QueryParam): Configuration parameters for query execution. + prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. + + Returns: + str: The result of the query execution. + """ if param.mode in ["local", "global", "hybrid"]: response = await kg_query( query, @@ -1097,7 +1097,7 @@ class LightRAG: async def aquery_with_separate_keyword_extraction( self, query: str, prompt: str, param: QueryParam = QueryParam() - ): + ) -> str | AsyncIterator[str]: """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed. 
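For reference, a minimal sketch of driving this two-step path from user code. It assumes an already-configured `rag` instance and that the synchronous wrapper is named `query_with_separate_keyword_extraction`, mirroring the async method above; note that the hunk below also fixes the accidental 1-tuple assignment so `param.hl_keywords`/`param.ll_keywords` receive plain lists:

```python
from lightrag import QueryParam

# Assumes `rag` is a configured LightRAG instance.
# Step 1 (keyword extraction) and step 2 (the main query) happen internally;
# the extracted HL/LL keywords are injected into param before querying.
response = rag.query_with_separate_keyword_extraction(
    query="What are the top themes in this story?",
    prompt="Answer in bullet points.",
    param=QueryParam(mode="hybrid"),
)
print(response)
```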
@@ -1120,8 +1120,8 @@ class LightRAG: ), ) - param.hl_keywords = (hl_keywords,) - param.ll_keywords = (ll_keywords,) + param.hl_keywords = hl_keywords + param.ll_keywords = ll_keywords # --------------------- # STEP 2: Final Query Logic @@ -1149,7 +1149,7 @@ class LightRAG: self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), global_config=asdict(self), - embedding_func=self.embedding_funcne, + embedding_func=self.embedding_func, ), ) elif param.mode == "naive": @@ -1198,12 +1198,7 @@ class LightRAG: return response async def _query_done(self): - tasks = [] - for storage_inst in [self.llm_response_cache]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) - await asyncio.gather(*tasks) + await self.llm_response_cache.index_done_callback() def delete_by_entity(self, entity_name: str): loop = always_get_an_event_loop() @@ -1225,16 +1220,16 @@ class LightRAG: logger.error(f"Error while deleting entity '{entity_name}': {e}") async def _delete_by_entity_done(self): - tasks = [] - for storage_inst in [ - self.entities_vdb, - self.relationships_vdb, - self.chunk_entity_relation_graph, - ]: - if storage_inst is None: - continue - tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) - await asyncio.gather(*tasks) + await asyncio.gather( + *[ + cast(StorageNameSpace, storage_inst).index_done_callback() + for storage_inst in [ # type: ignore + self.entities_vdb, + self.relationships_vdb, + self.chunk_entity_relation_graph, + ] + ] + ) def _get_content_summary(self, content: str, max_length: int = 100) -> str: """Get summary of document content @@ -1259,7 +1254,7 @@ class LightRAG: """ return await self.doc_status.get_status_counts() - async def adelete_by_doc_id(self, doc_id: str): + async def adelete_by_doc_id(self, doc_id: str) -> None: """Delete a document and all its related data Args: @@ -1276,6 +1271,9 @@ class LightRAG: # 2. 
Get all related chunks chunks = await self.text_chunks.get_by_id(doc_id) + if not chunks: + return + chunk_ids = list(chunks.keys()) logger.debug(f"Found {len(chunk_ids)} chunks to delete") @@ -1446,13 +1444,9 @@ class LightRAG: except Exception as e: logger.error(f"Error while deleting document {doc_id}: {e}") - def delete_by_doc_id(self, doc_id: str): - """Synchronous version of adelete""" - return asyncio.run(self.adelete_by_doc_id(doc_id)) - async def get_entity_info( self, entity_name: str, include_vector_data: bool = False - ): + ) -> dict[str, str | None | dict[str, str]]: """Get detailed information of an entity Args: @@ -1472,7 +1466,7 @@ class LightRAG: node_data = await self.chunk_entity_relation_graph.get_node(entity_name) source_id = node_data.get("source_id") if node_data else None - result = { + result: dict[str, str | None | dict[str, str]] = { "entity_name": entity_name, "source_id": source_id, "graph_data": node_data, @@ -1486,21 +1480,6 @@ class LightRAG: return result - def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False): - """Synchronous version of getting entity information - - Args: - entity_name: Entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - """ - try: - import tracemalloc - - tracemalloc.start() - return asyncio.run(self.get_entity_info(entity_name, include_vector_data)) - finally: - tracemalloc.stop() - async def get_relation_info( self, src_entity: str, tgt_entity: str, include_vector_data: bool = False ): @@ -1528,7 +1507,7 @@ class LightRAG: ) source_id = edge_data.get("source_id") if edge_data else None - result = { + result: dict[str, str | None | dict[str, str]] = { "src_entity": src_entity, "tgt_entity": tgt_entity, "source_id": source_id, @@ -1542,23 +1521,3 @@ class LightRAG: result["vector_data"] = vector_data[0] if vector_data else None return result - - def get_relation_info_sync( - self, src_entity: str, tgt_entity: str, include_vector_data: bool = False - ): - """Synchronous version of getting relationship information - - Args: - src_entity: Source entity name (no need for quotes) - tgt_entity: Target entity name (no need for quotes) - include_vector_data: Whether to include data from the vector database - """ - try: - import tracemalloc - - tracemalloc.start() - return asyncio.run( - self.get_relation_info(src_entity, tgt_entity, include_vector_data) - ) - finally: - tracemalloc.stop() diff --git a/lightrag/llm.py b/lightrag/llm.py index 3ca17725..e5f98cf8 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1,4 +1,6 @@ -from typing import List, Dict, Callable, Any +from __future__ import annotations + +from typing import Callable, Any from pydantic import BaseModel, Field @@ -23,7 +25,7 @@ class Model(BaseModel): ..., description="A function that generates the response from the llm. The response must be a string", ) - kwargs: Dict[str, Any] = Field( + kwargs: dict[str, Any] = Field( ..., description="The arguments to pass to the callable function. Eg. 
the api key, model name, etc", ) @@ -57,7 +59,7 @@ class MultiModel: ``` """ - def __init__(self, models: List[Model]): + def __init__(self, models: list[Model]): self._models = models self._current_model = 0 @@ -66,7 +68,11 @@ class MultiModel: return self._models[self._current_model] async def llm_model_func( - self, prompt, system_prompt=None, history_messages=[], **kwargs + self, + prompt: str, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] = [], + **kwargs: Any, ) -> str: kwargs.pop("model", None) # stop from overwriting the custom model name kwargs.pop("keyword_extraction", None) diff --git a/lightrag/namespace.py b/lightrag/namespace.py index ba8e3072..77e04c9e 100644 --- a/lightrag/namespace.py +++ b/lightrag/namespace.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Iterable diff --git a/lightrag/operate.py b/lightrag/operate.py index 8cf77f57..5c80a4d1 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import json import re from tqdm.asyncio import tqdm as tqdm_async -from typing import Any, Union +from typing import Any, AsyncIterator from collections import Counter, defaultdict from .utils import ( logger, @@ -36,7 +38,7 @@ import time def chunking_by_token_size( content: str, - split_by_character: Union[str, None] = None, + split_by_character: str | None = None, split_by_character_only: bool = False, overlap_token_size: int = 128, max_token_size: int = 1024, @@ -335,9 +337,9 @@ async def extract_entities( knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - global_config: dict, - llm_response_cache: BaseKVStorage = None, -) -> Union[BaseGraphStorage, None]: + global_config: dict[str, str], + llm_response_cache: BaseKVStorage | None = None, +) -> BaseGraphStorage | None: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] enable_llm_cache_for_entity_extract: bool = global_config[ @@ -603,15 +605,15 @@ async def extract_entities( async def kg_query( - query, + query: str, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, - prompt: str = "", + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + prompt: str | None = None, ) -> str: # Handle cache use_model_func = global_config["llm_model_func"] @@ -721,8 +723,8 @@ async def kg_query( async def extract_keywords_only( text: str, param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, ) -> tuple[list[str], list[str]]: """ Extract high-level and low-level keywords from the given 'text' using the LLM. @@ -818,9 +820,9 @@ async def mix_kg_vector_query( chunks_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, -) -> str: + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, +) -> str | AsyncIterator[str]: """ Hybrid retrieval implementation combining knowledge graph and vector search. 
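`mix_kg_vector_query` backs the `"mix"` query mode, which merges knowledge-graph context with plain vector-retrieved chunks. A minimal sketch of reaching it through the public API, assuming a configured `rag` instance:

```python
from lightrag import QueryParam

# "mix" mode routes through mix_kg_vector_query: KG entities/relations and
# raw chunk vectors are retrieved, then combined into a single context.
answer = rag.query(
    "How do the main characters relate to each other?",
    param=QueryParam(mode="mix", top_k=60),
)
print(answer)
```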
@@ -1539,13 +1541,13 @@ def combine_contexts(entities, relationships, sources): async def naive_query( - query, + query: str, chunks_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, -): + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, +) -> str | AsyncIterator[str]: # Handle cache use_model_func = global_config["llm_model_func"] args_hash = compute_args_hash(query_param.mode, query, cache_type="query") @@ -1646,9 +1648,9 @@ async def kg_query_with_keywords( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict, - hashing_kv: BaseKVStorage = None, -) -> str: + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, +) -> str | AsyncIterator[str]: """ Refactored kg_query that does NOT extract keywords by itself. It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 160663d9..f4f5e38a 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -1,3 +1,5 @@ +from __future__ import annotations + GRAPH_FIELD_SEP = "" PROMPTS = {} diff --git a/lightrag/types.py b/lightrag/types.py index d2670ddc..5e3d2948 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -1,16 +1,18 @@ +from __future__ import annotations + from pydantic import BaseModel -from typing import List, Dict, Any, Optional +from typing import Any, Optional class GPTKeywordExtractionFormat(BaseModel): - high_level_keywords: List[str] - low_level_keywords: List[str] + high_level_keywords: list[str] + low_level_keywords: list[str] class KnowledgeGraphNode(BaseModel): id: str - labels: List[str] - properties: Dict[str, Any] # anything else goes here + labels: list[str] + properties: dict[str, Any] # anything else goes here class KnowledgeGraphEdge(BaseModel): @@ -18,9 +20,9 @@ class KnowledgeGraphEdge(BaseModel): type: Optional[str] source: str # id of source node target: str # id of target node - properties: Dict[str, Any] # anything else goes here + properties: dict[str, Any] # anything else goes here class KnowledgeGraph(BaseModel): - nodes: List[KnowledgeGraphNode] = [] - edges: List[KnowledgeGraphEdge] = [] + nodes: list[KnowledgeGraphNode] = [] + edges: list[KnowledgeGraphEdge] = [] diff --git a/lightrag/utils.py b/lightrag/utils.py index 9df325ca..c8786e7b 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import html import io @@ -9,7 +11,7 @@ import re from dataclasses import dataclass from functools import wraps from hashlib import md5 -from typing import Any, Union, List, Optional +from typing import Any, Callable import xml.etree.ElementTree as ET import bs4 @@ -67,12 +69,12 @@ class EmbeddingFunc: @dataclass class ReasoningResponse: - reasoning_content: str + reasoning_content: str | None response_content: str tag: str -def locate_json_string_body_from_string(content: str) -> Union[str, None]: +def locate_json_string_body_from_string(content: str) -> str | None: """Locate the JSON string body from a string""" try: maybe_json_str = re.search(r"{.*}", content, re.DOTALL) @@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]: raise e from None -def compute_args_hash(*args, cache_type: str = None) -> str: +def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: """Compute a hash for the given arguments. 
Args: *args: Arguments to hash @@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str: return hashlib.md5(args_str.encode()).hexdigest() -def compute_mdhash_id(content, prefix: str = ""): +def compute_mdhash_id(content: str, prefix: str = "") -> str: + """ + Compute a unique ID for a given content string. + + The ID is a combination of the given prefix and the MD5 hash of the content string. + """ return prefix + md5(content.encode()).hexdigest() @@ -215,11 +222,13 @@ def clean_str(input: Any) -> str: return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) -def is_float_regex(value): +def is_float_regex(value: str) -> bool: return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) -def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): +def truncate_list_by_token_size( + list_data: list[Any], key: Callable[[Any], str], max_token_size: int +) -> list[int]: """Truncate a list of data by token size""" if max_token_size <= 0: return [] @@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: return list_data -def list_of_list_to_csv(data: List[List[str]]) -> str: +def list_of_list_to_csv(data: list[list[str]]) -> str: output = io.StringIO() writer = csv.writer( output, @@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str: return output.getvalue() -def csv_string_to_list(csv_string: str) -> List[List[str]]: +def csv_string_to_list(csv_string: str) -> list[list[str]]: # Clean the string by removing NUL characters cleaned_string = csv_string.replace("\0", "") @@ -329,7 +338,7 @@ def xml_to_json(xml_file): return None -def process_combine_contexts(hl, ll): +def process_combine_contexts(hl: str, ll: str): header = None list_hl = csv_string_to_list(hl.strip()) list_ll = csv_string_to_list(ll.strip()) @@ -375,7 +384,7 @@ async def get_best_cached_response( llm_func=None, original_prompt=None, cache_type=None, -) -> Union[str, None]: +) -> str | None: logger.debug( f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}" ) @@ -479,7 +488,7 @@ def cosine_similarity(v1, v2): return dot_product / (norm1 * norm2) -def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple: +def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple: """Quantize embedding to specified bits""" # Convert list to numpy array if needed if isinstance(embedding, list): @@ -570,9 +579,9 @@ class CacheData: args_hash: str content: str prompt: str - quantized: Optional[np.ndarray] = None - min_val: Optional[float] = None - max_val: Optional[float] = None + quantized: np.ndarray | None = None + min_val: float | None = None + max_val: float | None = None mode: str = "default" cache_type: str = "query" @@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool: return False -def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str: +def get_conversation_turns( + conversation_history: list[dict[str, Any]], num_turns: int +) -> str: """ Process conversation history to get the specified number of complete turns. 
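The tightened typing here complements the `conversation_history` change in `lightrag/base.py`. A short sketch of the expected message shape and the related `history_turns` knob, assuming a configured `rag` instance:

```python
from lightrag import QueryParam

# Each history entry is a flat dict[str, str]: a "role" of "user" or
# "assistant" plus the message "content".
history = [
    {"role": "user", "content": "Who wrote the report?"},
    {"role": "assistant", "content": "The 2024 audit team."},
]
response = rag.query(
    "What did they conclude?",
    param=QueryParam(conversation_history=history, history_turns=1),
)
```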
@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> Formatted string of the conversation history """ # Group messages into turns - turns = [] - messages = [] + turns: list[list[dict[str, Any]]] = [] + messages: list[dict[str, Any]] = [] # First, filter out keyword extraction messages for msg in conversation_history: @@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> turns = turns[-num_turns:] # Format the turns into a string - formatted_turns = [] + formatted_turns: list[str] = [] for turn in turns: formatted_turns.extend( [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
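Taken together with the `query`/`aquery` changes above, callers now pass `QueryParam` as the second argument and supply any custom prompt by keyword. A closing sketch of the new calling convention, assuming a configured `rag` instance:

```python
from lightrag import QueryParam

# param now precedes the optional prompt; prompt=None falls back to
# PROMPTS["rag_response"] internally.
answer = rag.query(
    "Summarize the key entities.",
    param=QueryParam(mode="local"),
    prompt="Answer concisely, citing entity names.",  # hypothetical custom prompt
)
print(answer)
```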