cleaning the mess

This commit is contained in:
Yannick Stephan
2025-02-14 22:50:49 +01:00
parent 4d58ff8bb4
commit 28c8443ff2

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import os
import configparser
@@ -91,7 +93,7 @@ STORAGE_IMPLEMENTATIONS = {
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS = {
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
@@ -176,7 +178,7 @@ STORAGES = {
}
def lazy_external_import(module_name: str, class_name: str):
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
@@ -185,7 +187,7 @@ def lazy_external_import(module_name: str, class_name: str):
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args, **kwargs):
def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
@@ -302,7 +304,7 @@ class LightRAG:
- random_seed: Seed value for reproducibility.
"""
embedding_func: EmbeddingFunc = None
embedding_func: Union[EmbeddingFunc, None] = None
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = 32
@@ -312,7 +314,7 @@ class LightRAG:
"""Maximum number of concurrent embedding function calls."""
# LLM Configuration
llm_model_func: callable = None
llm_model_func: Union[Callable[..., object], None] = None
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -443,77 +445,77 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( # type: ignore
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( # type: ignore
self.vector_storage
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( # type: ignore
self.graph_storage
)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls = partial( # type: ignore
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls = partial( # type: ignore
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls = partial( # type: ignore
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.llm_response_cache = self.key_string_value_json_storage_cls(
self.llm_response_cache = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
self.entities_vdb = self.vector_db_storage_cls(
self.entities_vdb = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
self.relationships_vdb = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
@@ -533,7 +535,7 @@ class LightRAG:
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
@@ -542,7 +544,7 @@ class LightRAG:
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
@@ -559,68 +561,45 @@ class LightRAG:
node_label=nodel_label, max_depth=max_depth
)
def _get_storage_class(self, storage_name: str) -> dict:
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
def set_storage_client(self, db_client):
# Deprecated: set the correct value for the *_storage fields of LightRAG instead.
# Inject db to storage implementation (only tested on Oracle Database)
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,
self.doc_status,
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.key_string_value_json_storage_cls,
self.chunks_vdb,
self.relationships_vdb,
self.entities_vdb,
self.graph_storage_cls,
self.chunk_entity_relation_graph,
self.llm_response_cache,
]:
# set client
storage.db = db_client
def insert(
self,
string_or_strings: Union[str, list[str]],
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Sync Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
input: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
chunk_size, split the sub chunk by token size.
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.ainsert(string_or_strings, split_by_character, split_by_character_only)
self.ainsert(input, split_by_character, split_by_character_only)
)
async def ainsert(
self,
string_or_strings: Union[str, list[str]],
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Async Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
input: Single document string or list of document strings
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
chunk_size, split the sub chunk by token size.
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
"""
await self.apipeline_enqueue_documents(string_or_strings)
await self.apipeline_enqueue_documents(input)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
@@ -677,7 +656,7 @@ class LightRAG:
if update_storage:
await self._insert_done()
async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
async def apipeline_enqueue_documents(self, input: str | list[str]):
"""
Pipeline for Processing Documents
@@ -686,11 +665,11 @@ class LightRAG:
3. Filter out already processed documents
4. Enqueue document in status
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
if isinstance(input, str):
input = [input]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
unique_contents = list(set(doc.strip() for doc in input))
# 2. Generate document IDs and initial status
new_docs: dict[str, Any] = {
@@ -872,11 +851,11 @@ class LightRAG:
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
def insert_custom_kg(self, custom_kg: dict):
def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
async def ainsert_custom_kg(self, custom_kg: dict):
async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
update_storage = False
try:
# Insert chunks into vector storage