From 28c8443ff2e3688ba244f126c892354511ad7c6b Mon Sep 17 00:00:00 2001
From: Yannick Stephan
Date: Fri, 14 Feb 2025 22:50:49 +0100
Subject: [PATCH] cleaning the mess

---
 lightrag/lightrag.py | 93 +++++++++++++++++---------------------------
 1 file changed, 36 insertions(+), 57 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 9f74c917..fcea2c57 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import os
 import configparser
@@ -91,7 +93,7 @@ STORAGE_IMPLEMENTATIONS = {
 }
 
 # Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS = {
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     # KV Storage Implementations
     "JsonKVStorage": [],
     "MongoKVStorage": [],
@@ -176,7 +178,7 @@ STORAGES = {
 }
 
 
-def lazy_external_import(module_name: str, class_name: str):
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     """Lazily import a class from an external module based on the package of the caller."""
     # Get the caller's module and package
     import inspect
@@ -185,7 +187,7 @@ def lazy_external_import(module_name: str, class_name: str):
     module = inspect.getmodule(caller_frame)
     package = module.__package__ if module else None
 
-    def import_class(*args, **kwargs):
+    def import_class(*args: Any, **kwargs: Any):
         import importlib
 
         module = importlib.import_module(module_name, package=package)
@@ -302,7 +304,7 @@ class LightRAG:
         - random_seed: Seed value for reproducibility.
     """
 
-    embedding_func: EmbeddingFunc = None
+    embedding_func: Union[EmbeddingFunc, None] = None
    """Function for computing text embeddings. Must be set before use."""
 
     embedding_batch_num: int = 32
@@ -312,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""
 
     # LLM Configuration
-    llm_model_func: callable = None
+    llm_model_func: Union[Callable[..., object], None] = None
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -443,77 +445,77 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
         # Init LLM
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
             self.embedding_func
         )
 
         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (  # type: ignore
             self._get_storage_class(self.kv_storage)
         )
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(  # type: ignore
             self.vector_storage
         )
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(  # type: ignore
             self.graph_storage
         )
-        self.key_string_value_json_storage_cls = partial(
+        self.key_string_value_json_storage_cls = partial(  # type: ignore
             self.key_string_value_json_storage_cls, global_config=global_config
         )
-        self.vector_db_storage_cls = partial(
+        self.vector_db_storage_cls = partial(  # type: ignore
             self.vector_db_storage_cls, global_config=global_config
         )
-        self.graph_storage_cls = partial(
+        self.graph_storage_cls = partial(  # type: ignore
             self.graph_storage_cls, global_config=global_config
         )
 
         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
 
-        self.llm_response_cache = self.key_string_value_json_storage_cls(
+        self.llm_response_cache = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
             embedding_func=self.embedding_func,
         )
 
-        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
-        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
-        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
+        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
             embedding_func=self.embedding_func,
         )
 
-        self.entities_vdb = self.vector_db_storage_cls(
+        self.entities_vdb = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
-        self.relationships_vdb = self.vector_db_storage_cls(
+        self.relationships_vdb = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
-        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
             ),
@@ -533,7 +535,7 @@ class LightRAG:
         ):
             hashing_kv = self.llm_response_cache
         else:
-            hashing_kv = self.key_string_value_json_storage_cls(
+            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                 namespace=make_namespace(
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
@@ -542,7 +544,7 @@ class LightRAG:
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
-                self.llm_model_func,
+                self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
                 **self.llm_model_kwargs,
             )
@@ -559,68 +561,45 @@ class LightRAG:
             node_label=nodel_label, max_depth=max_depth
         )
 
-    def _get_storage_class(self, storage_name: str) -> dict:
+    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class
 
-    def set_storage_client(self, db_client):
-        # Deprecated, seting correct value to *_storage of LightRAG insteaded
-        # Inject db to storage implementation (only tested on Oracle Database)
-        for storage in [
-            self.vector_db_storage_cls,
-            self.graph_storage_cls,
-            self.doc_status,
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.key_string_value_json_storage_cls,
-            self.chunks_vdb,
-            self.relationships_vdb,
-            self.entities_vdb,
-            self.graph_storage_cls,
-            self.chunk_entity_relation_graph,
-            self.llm_response_cache,
-        ]:
-            # set client
-            storage.db = db_client
-
     def insert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Sync Insert documents with checkpoint support
 
         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
-            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only)
         )
 
     async def ainsert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Async Insert documents with checkpoint support
 
         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
-        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_enqueue_documents(input)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
@@ -677,7 +656,7 @@ class LightRAG:
         if update_storage:
             await self._insert_done()
 
-    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+    async def apipeline_enqueue_documents(self, input: str | list[str]):
         """
         Pipeline for Processing Documents
 
@@ -686,11 +665,11 @@ class LightRAG:
         3. Filter out already processed documents
         4. Enqueue document in status
         """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
+        if isinstance(input, str):
+            input = [input]
 
         # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+        unique_contents = list(set(doc.strip() for doc in input))
 
         # 2. Generate document IDs and initial status
         new_docs: dict[str, Any] = {
@@ -872,11 +851,11 @@ class LightRAG:
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
 
-    def insert_custom_kg(self, custom_kg: dict):
+    def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
 
-    async def ainsert_custom_kg(self, custom_kg: dict):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
         update_storage = False
         try:
             # Insert chunks into vector storage
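
The `from __future__ import annotations` line added in the first hunk is what lets the new signatures use PEP 604 unions such as `str | list[str]` while the package still runs on Python versions below 3.10: with the future import, annotations are stored as strings (PEP 563) and never evaluated at runtime. A standalone sketch, separate from the patch:

    # Standalone sketch: with the future import, the union annotation below
    # is kept as a string, so this module parses and runs on Python 3.8/3.9
    # as well as 3.10+.
    from __future__ import annotations


    def insert(input: str | list[str]) -> None:
        # Same normalization the patched ainsert pipeline performs.
        docs = [input] if isinstance(input, str) else input
        print(f"enqueueing {len(docs)} document(s)")


    insert("a single document")
    insert(["doc one", "doc two"])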
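`lazy_external_import`, whose return type this patch annotates as `Callable[..., Any]`, hands back a factory that defers the actual import until the storage class is first instantiated, so optional backends (Mongo, Oracle, Neo4j, ...) are only imported when selected via `_get_storage_class`. A self-contained sketch of the same pattern, using `pathlib` as a stand-in for a storage module:

    # Self-contained sketch of the lazy-import pattern: the module is
    # resolved only when the returned factory is first called.
    import importlib
    from typing import Any, Callable


    def lazy_import(module_name: str, class_name: str) -> Callable[..., Any]:
        def import_class(*args: Any, **kwargs: Any):
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            return cls(*args, **kwargs)

        return import_class


    # Nothing is imported until the factory is actually invoked.
    make_path = lazy_import("pathlib", "Path")
    print(make_path("/tmp"))  # pathlib is imported here, then Path("/tmp") is built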
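Both the embedding function and the LLM function are wrapped with `limit_async_func_call(...)` in `__post_init__`, capping how many calls run concurrently (`embedding_func_max_async` and `llm_model_max_async`). The real helper lives in `lightrag.utils`; the semaphore-based equivalent below is an assumption about the mechanism, not a copy of it:

    # Minimal sketch of an async concurrency limiter in the spirit of
    # limit_async_func_call: at most max_async wrapped calls run at once.
    import asyncio
    from functools import wraps


    def limit_async_calls(max_async: int):
        def decorator(func):
            semaphore = asyncio.Semaphore(max_async)

            @wraps(func)
            async def wrapper(*args, **kwargs):
                async with semaphore:
                    return await func(*args, **kwargs)

            return wrapper

        return decorator


    async def fake_llm(prompt: str) -> str:
        await asyncio.sleep(0.1)  # stand-in for a model call
        return f"echo: {prompt}"


    async def main() -> None:
        limited = limit_async_calls(2)(fake_llm)  # mirrors the wrapping in __post_init__
        print(await asyncio.gather(*(limited(f"q{i}") for i in range(5))))


    asyncio.run(main())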
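The enqueue hunks above show steps 1 and 2 of `apipeline_enqueue_documents`: strip each incoming document, drop exact duplicates, then key new documents by a content hash. The runnable sketch below inlines an MD5-based stand-in for LightRAG's `compute_mdhash_id` helper; the `doc-` prefix is my assumption and does not appear in this diff:

    # Runnable sketch of the dedup + ID steps; compute_mdhash_id here is a
    # stand-in for the helper in lightrag.utils (assumed MD5-based).
    from __future__ import annotations

    from hashlib import md5
    from typing import Any


    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        return prefix + md5(content.encode()).hexdigest()


    input = ["hello world", "hello world  ", "second doc"]

    # 1. Remove duplicate contents from the list (set order is arbitrary).
    unique_contents = list(set(doc.strip() for doc in input))

    # 2. Generate document IDs keyed by content hash.
    new_docs: dict[str, Any] = {
        compute_mdhash_id(content, prefix="doc-"): {"content": content}
        for content in unique_contents
    }
    print(new_docs)  # two entries: the whitespace duplicate collapses to one ID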