cleaning the mess

Author: Yannick Stephan
Date: 2025-02-14 22:50:49 +01:00
commit 28c8443ff2 (parent 4d58ff8bb4)

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import os
 import configparser
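
Note: `from __future__ import annotations` (PEP 563) stores every annotation as a string instead of evaluating it, which is what lets the PEP 604 `str | list[str]` unions and built-in generics like `dict[str, list[str]]` used later in this diff run on Python versions older than 3.10. A minimal standalone illustration (not from this repo):

```python
from __future__ import annotations


def greet(name: str | list[str]) -> None:
    """Accept one name or several."""


# The union survives as an unevaluated string, so this works on 3.8+;
# without the future import, Python 3.9 raises TypeError at `def` time.
print(greet.__annotations__)  # {'name': 'str | list[str]', 'return': 'None'}
```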
@@ -91,7 +93,7 @@ STORAGE_IMPLEMENTATIONS = {
 }

 # Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS = {
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     # KV Storage Implementations
     "JsonKVStorage": [],
     "MongoKVStorage": [],
@@ -176,7 +178,7 @@ STORAGES = {
 }


-def lazy_external_import(module_name: str, class_name: str):
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     """Lazily import a class from an external module based on the package of the caller."""
     # Get the caller's module and package
     import inspect
@@ -185,7 +187,7 @@ def lazy_external_import(module_name: str, class_name: str):
     module = inspect.getmodule(caller_frame)
     package = module.__package__ if module else None

-    def import_class(*args, **kwargs):
+    def import_class(*args: Any, **kwargs: Any):
         import importlib

         module = importlib.import_module(module_name, package=package)
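
The hunks above tighten `lazy_external_import` so both the factory and its inner `import_class` are fully typed. For reference, here is a simplified, self-contained sketch of the pattern (it drops the caller-package inspection the real function performs):

```python
from typing import Any, Callable


def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    """Return a factory that imports `module_name` only on first call."""

    def import_class(*args: Any, **kwargs: Any) -> Any:
        import importlib

        module = importlib.import_module(module_name)
        return getattr(module, class_name)(*args, **kwargs)

    return import_class


# Nothing is imported until the factory is actually invoked:
JSONDecoder = lazy_external_import("json.decoder", "JSONDecoder")
decoder = JSONDecoder()  # json.decoder is imported here
```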
@@ -302,7 +304,7 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """

-    embedding_func: EmbeddingFunc = None
+    embedding_func: Union[EmbeddingFunc, None] = None
    """Function for computing text embeddings. Must be set before use."""

    embedding_batch_num: int = 32
@@ -312,7 +314,7 @@ class LightRAG:
    """Maximum number of concurrent embedding function calls."""

    # LLM Configuration
-    llm_model_func: callable = None
+    llm_model_func: Union[Callable[..., object], None] = None
    """Function for interacting with the large language model (LLM). Must be set before use."""

    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
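
Both field changes replace annotations a type checker cannot use: `callable` is the built-in predicate function, not a type, and `EmbeddingFunc = None` assigns None to a non-optional field. Spelling out `Union[..., None]` makes the "must be set before use" contract checkable, as in this minimal sketch (hypothetical class, not the real `LightRAG` dataclass):

```python
from dataclasses import dataclass
from typing import Any, Callable, Union


@dataclass
class ModelConfig:
    # Explicitly optional: mypy now forces a None check before calling.
    llm_model_func: Union[Callable[..., object], None] = None

    def call_llm(self, prompt: str) -> Any:
        if self.llm_model_func is None:  # narrow the Union before use
            raise RuntimeError("llm_model_func must be set before use")
        return self.llm_model_func(prompt)
```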
@@ -443,77 +445,77 @@ class LightRAG:
        logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

        # Init LLM
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
            self.embedding_func
        )

        # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (  # type: ignore
            self._get_storage_class(self.kv_storage)
        )
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(  # type: ignore
            self.vector_storage
        )
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(  # type: ignore
            self.graph_storage
        )
-        self.key_string_value_json_storage_cls = partial(
+        self.key_string_value_json_storage_cls = partial(  # type: ignore
            self.key_string_value_json_storage_cls, global_config=global_config
        )
-        self.vector_db_storage_cls = partial(
+        self.vector_db_storage_cls = partial(  # type: ignore
            self.vector_db_storage_cls, global_config=global_config
        )
-        self.graph_storage_cls = partial(
+        self.graph_storage_cls = partial(  # type: ignore
            self.graph_storage_cls, global_config=global_config
        )

        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

-        self.llm_response_cache = self.key_string_value_json_storage_cls(
+        self.llm_response_cache = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
            ),
            embedding_func=self.embedding_func,
        )

-        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
            ),
            embedding_func=self.embedding_func,
        )

-        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
            ),
            embedding_func=self.embedding_func,
        )

-        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
+        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
            ),
            embedding_func=self.embedding_func,
        )

-        self.entities_vdb = self.vector_db_storage_cls(
+        self.entities_vdb = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
            ),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )

-        self.relationships_vdb = self.vector_db_storage_cls(
+        self.relationships_vdb = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
            ),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )

-        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
            ),
@@ -533,7 +535,7 @@ class LightRAG:
        ):
            hashing_kv = self.llm_response_cache
        else:
-            hashing_kv = self.key_string_value_json_storage_cls(
+            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                namespace=make_namespace(
                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                ),
@@ -542,7 +544,7 @@ class LightRAG:
        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
-                self.llm_model_func,
+                self.llm_model_func,  # type: ignore
                hashing_kv=hashing_kv,
                **self.llm_model_kwargs,
            )
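
`limit_async_func_call(n)` wraps both the embedding and LLM callables in this file; judging by its name and usage, it caps how many invocations run concurrently. A plausible minimal sketch of that technique with an `asyncio.Semaphore` (an assumption for illustration, not the upstream implementation):

```python
import asyncio
from typing import Any, Callable


def limit_async_func_call(max_size: int) -> Callable[..., Any]:
    """Decorator factory: cap concurrent calls to an async function."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        sem = asyncio.Semaphore(max_size)

        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            async with sem:  # at most `max_size` calls in flight
                return await func(*args, **kwargs)

        return wrapper

    return decorator
```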
@@ -559,68 +561,45 @@ class LightRAG:
            node_label=nodel_label, max_depth=max_depth
        )

-    def _get_storage_class(self, storage_name: str) -> dict:
+    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
        import_path = STORAGES[storage_name]
        storage_class = lazy_external_import(import_path, storage_name)
        return storage_class

-    def set_storage_client(self, db_client):
-        # Deprecated, seting correct value to *_storage of LightRAG insteaded
-        # Inject db to storage implementation (only tested on Oracle Database)
-        for storage in [
-            self.vector_db_storage_cls,
-            self.graph_storage_cls,
-            self.doc_status,
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.key_string_value_json_storage_cls,
-            self.chunks_vdb,
-            self.relationships_vdb,
-            self.entities_vdb,
-            self.graph_storage_cls,
-            self.chunk_entity_relation_graph,
-            self.llm_response_cache,
-        ]:
-            # set client
-            storage.db = db_client
-
    def insert(
        self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
        split_by_character: str | None = None,
        split_by_character_only: bool = False,
    ):
        """Sync Insert documents with checkpoint support

        Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
            split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
                chunk_size, split the sub chunk by token size.
            split_by_character_only: if split_by_character_only is True, split the string by character only, when
                split_by_character is None, this parameter is ignored.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(
-            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only)
        )

    async def ainsert(
        self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
        split_by_character: str | None = None,
        split_by_character_only: bool = False,
    ):
        """Async Insert documents with checkpoint support

        Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
            split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
                chunk_size, split the sub chunk by token size.
            split_by_character_only: if split_by_character_only is True, split the string by character only, when
                split_by_character is None, this parameter is ignored.
        """
-        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_enqueue_documents(input)
        await self.apipeline_process_enqueue_documents(
            split_by_character, split_by_character_only
        )
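
Renaming the public `string_or_strings` parameter to `input` is a breaking change for keyword callers. Usage against the new signature (the `working_dir` construction is a hypothetical minimal config; the embedding and LLM functions must also be set per the fields above):

```python
from lightrag import LightRAG

rag = LightRAG(working_dir="./rag_storage")  # hypothetical minimal config

# Positional callers are unaffected by the rename:
rag.insert("a single document")

# Keyword callers must switch from string_or_strings= to input=:
rag.insert(
    input=["doc one", "doc two"],
    split_by_character="\n\n",      # split on this separator first
    split_by_character_only=False,  # then re-split oversized chunks by tokens
)
```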
@@ -677,7 +656,7 @@ class LightRAG:
        if update_storage:
            await self._insert_done()

-    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+    async def apipeline_enqueue_documents(self, input: str | list[str]):
        """
        Pipeline for Processing Documents
@@ -686,11 +665,11 @@ class LightRAG:
        3. Filter out already processed documents
        4. Enqueue document in status
        """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
+        if isinstance(input, str):
+            input = [input]

        # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+        unique_contents = list(set(doc.strip() for doc in input))

        # 2. Generate document IDs and initial status
        new_docs: dict[str, Any] = {
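
Step 1 above dedupes by stripped content; step 2 then derives document IDs. A runnable sketch of that enqueue logic, assuming content-hash IDs (the `doc-` + md5 scheme is an illustrative assumption):

```python
from hashlib import md5

docs = ["  same text ", "same text", "other text"]

# 1. Remove duplicate contents after stripping whitespace, as in the hunk.
unique_contents = list({doc.strip() for doc in docs})

# 2. Generate a stable, content-addressed ID and initial status per document.
new_docs = {
    "doc-" + md5(content.encode()).hexdigest(): {"content": content}
    for content in unique_contents
}
assert len(new_docs) == 2  # the two identical strings collapsed into one
```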
@@ -872,11 +851,11 @@ class LightRAG:
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

-    def insert_custom_kg(self, custom_kg: dict):
+    def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

-    async def ainsert_custom_kg(self, custom_kg: dict):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
        update_storage = False
        try:
            # Insert chunks into vector storage
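
Like `insert`, `insert_custom_kg` bridges sync callers to the async pipeline by fetching a loop and blocking on the coroutine. A minimal sketch of what an `always_get_an_event_loop` helper can look like (an assumption, not the upstream code):

```python
import asyncio


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Reuse the current thread's loop if possible, else create one."""
    try:
        return asyncio.get_event_loop()
    except RuntimeError:  # no loop set on this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


async def work() -> str:
    return "done"


loop = always_get_an_event_loop()
print(loop.run_until_complete(work()))  # prints "done"
```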