cleaning the mess
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import os
 import configparser
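For context: the future import makes every annotation in the module a lazily evaluated string (PEP 563), which is what lets later hunks write unions as str | list[str] while still supporting Python below 3.10. A standalone sketch of the effect (normalize is a hypothetical function, not LightRAG code):

    from __future__ import annotations


    # With the future import active, this annotation is stored as a string and
    # never evaluated at runtime, so the PEP 604 union syntax also works on
    # Python 3.8/3.9.
    def normalize(input: str | list[str]) -> list[str]:
        return [input] if isinstance(input, str) else input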
@@ -91,7 +93,7 @@ STORAGE_IMPLEMENTATIONS = {
 }
 
 # Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS = {
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     # KV Storage Implementations
     "JsonKVStorage": [],
     "MongoKVStorage": [],
@@ -176,7 +178,7 @@ STORAGES = {
 }
 
 
-def lazy_external_import(module_name: str, class_name: str):
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     """Lazily import a class from an external module based on the package of the caller."""
     # Get the caller's module and package
     import inspect
@@ -185,7 +187,7 @@ def lazy_external_import(module_name: str, class_name: str):
     module = inspect.getmodule(caller_frame)
     package = module.__package__ if module else None
 
-    def import_class(*args, **kwargs):
+    def import_class(*args: Any, **kwargs: Any):
         import importlib
 
         module = importlib.import_module(module_name, package=package)
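The two hunks above finish typing a small lazy-import factory: heavy storage backends are imported only when a storage class is actually instantiated. A standalone sketch of the same pattern, minus the caller-package resolution via inspect that the real helper performs to support relative imports:

    from __future__ import annotations

    import importlib
    from typing import Any, Callable


    def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
        """Return a factory that imports the module and instantiates the class on first call."""

        def import_class(*args: Any, **kwargs: Any):
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            return cls(*args, **kwargs)

        return import_class


    # Nothing is imported until the factory is actually called:
    counter_cls = lazy_external_import("collections", "Counter")
    counts = counter_cls("mississippi")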
@@ -302,7 +304,7 @@ class LightRAG:
         - random_seed: Seed value for reproducibility.
     """
 
-    embedding_func: EmbeddingFunc = None
+    embedding_func: Union[EmbeddingFunc, None] = None
    """Function for computing text embeddings. Must be set before use."""
 
    embedding_batch_num: int = 32
@@ -312,7 +314,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""
 
     # LLM Configuration
-    llm_model_func: callable = None
+    llm_model_func: Union[Callable[..., object], None] = None
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
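Both annotation fixes apply one rule: a field defaulting to None must be typed as optional, and the builtin callable is not a valid annotation for type checkers (typing.Callable is). A minimal sketch of the corrected pattern, with illustrative names rather than LightRAG's:

    from dataclasses import dataclass
    from typing import Any, Callable, Union


    @dataclass
    class ModelConfig:
        # Union[..., None] makes the None default type-correct; a bare
        # `func: Callable = None` is flagged by strict type checkers.
        func: Union[Callable[..., Any], None] = None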
@@ -443,77 +445,77 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
         # Init LLM
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
             self.embedding_func
         )
 
         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (  # type: ignore
             self._get_storage_class(self.kv_storage)
         )
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(  # type: ignore
             self.vector_storage
         )
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(  # type: ignore
             self.graph_storage
         )
 
-        self.key_string_value_json_storage_cls = partial(
+        self.key_string_value_json_storage_cls = partial(  # type: ignore
             self.key_string_value_json_storage_cls, global_config=global_config
         )
 
-        self.vector_db_storage_cls = partial(
+        self.vector_db_storage_cls = partial(  # type: ignore
             self.vector_db_storage_cls, global_config=global_config
         )
 
-        self.graph_storage_cls = partial(
+        self.graph_storage_cls = partial(  # type: ignore
             self.graph_storage_cls, global_config=global_config
         )
 
         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
 
-        self.llm_response_cache = self.key_string_value_json_storage_cls(
+        self.llm_response_cache = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
             embedding_func=self.embedding_func,
         )
 
-        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
-        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
-        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
+        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
             embedding_func=self.embedding_func,
         )
 
-        self.entities_vdb = self.vector_db_storage_cls(
+        self.entities_vdb = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
-        self.relationships_vdb = self.vector_db_storage_cls(
+        self.relationships_vdb = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
-        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
             ),
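Two patterns recur in the hunk above: limit_async_func_call caps concurrency around the embedding and LLM callables, and functools.partial pre-binds global_config into each storage class so call sites only pass per-store arguments. A minimal sketch of both (the semaphore-based limiter is an assumption about the helper's behavior, not LightRAG's actual implementation; JsonKVStorage here is a stand-in class):

    import asyncio
    from functools import partial


    def limit_async_func_call(max_size: int):
        """Decorator factory: cap how many invocations of an async function run concurrently."""

        def decorator(func):
            sem = asyncio.Semaphore(max_size)

            async def wrapper(*args, **kwargs):
                async with sem:
                    return await func(*args, **kwargs)

            return wrapper

        return decorator


    class JsonKVStorage:
        """Stand-in storage class; only the constructor signature matters here."""

        def __init__(self, namespace: str, global_config: dict):
            self.namespace = namespace
            self.global_config = global_config


    # partial-as-factory: pre-bind the shared config once, so every later call
    # site only supplies the per-store arguments.
    storage_cls = partial(JsonKVStorage, global_config={"working_dir": "./rag_storage"})
    cache = storage_cls(namespace="llm_response_cache")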
@@ -533,7 +535,7 @@ class LightRAG:
         ):
             hashing_kv = self.llm_response_cache
         else:
-            hashing_kv = self.key_string_value_json_storage_cls(
+            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                 namespace=make_namespace(
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
@@ -542,7 +544,7 @@ class LightRAG:
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
-                self.llm_model_func,
+                self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
                 **self.llm_model_kwargs,
             )
@@ -559,68 +561,45 @@ class LightRAG:
             node_label=nodel_label, max_depth=max_depth
         )
 
-    def _get_storage_class(self, storage_name: str) -> dict:
+    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class
-
-    def set_storage_client(self, db_client):
-        # Deprecated, seting correct value to *_storage of LightRAG insteaded
-        # Inject db to storage implementation (only tested on Oracle Database)
-        for storage in [
-            self.vector_db_storage_cls,
-            self.graph_storage_cls,
-            self.doc_status,
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.key_string_value_json_storage_cls,
-            self.chunks_vdb,
-            self.relationships_vdb,
-            self.entities_vdb,
-            self.graph_storage_cls,
-            self.chunk_entity_relation_graph,
-            self.llm_response_cache,
-        ]:
-            # set client
-            storage.db = db_client
 
     def insert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Sync insert of documents, with checkpoint support.
 
         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if not None, split the string by this character; if a chunk is longer
                 than chunk_size, further split the sub-chunk by token size.
             split_by_character_only: if True, split the string by character only; when split_by_character
                 is None, this parameter is ignored.
         """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
-            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only)
         )
 
     async def ainsert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Async insert of documents, with checkpoint support.
 
         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if not None, split the string by this character; if a chunk is longer
                 than chunk_size, further split the sub-chunk by token size.
             split_by_character_only: if True, split the string by character only; when split_by_character
                 is None, this parameter is ignored.
         """
-        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_enqueue_documents(input)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
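After the rename, the public entry point reads naturally with either a single string or a list. A hypothetical usage sketch (rag is assumed to be a fully configured LightRAG instance with embedding_func and llm_model_func set; it is not constructed here):

    rag.insert("A single document.")
    rag.insert(
        ["First document.", "Second document."],
        split_by_character="\n\n",      # split on blank lines first...
        split_by_character_only=False,  # ...then re-split oversized chunks by token size
    )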
@@ -677,7 +656,7 @@ class LightRAG:
         if update_storage:
             await self._insert_done()
 
-    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+    async def apipeline_enqueue_documents(self, input: str | list[str]):
         """
         Pipeline for Processing Documents
 
@@ -686,11 +665,11 @@ class LightRAG:
         3. Filter out already processed documents
         4. Enqueue document in status
         """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
+        if isinstance(input, str):
+            input = [input]
 
         # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+        unique_contents = list(set(doc.strip() for doc in input))
 
         # 2. Generate document IDs and initial status
         new_docs: dict[str, Any] = {
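Steps 1 and 2 of the enqueue pipeline are simple to sketch standalone. The MD5-based helper below mirrors the idea of LightRAG's compute_mdhash_id utility, and the status payload is simplified to just the content:

    import hashlib


    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        # A stable document ID derived from the MD5 hash of the content.
        return prefix + hashlib.md5(content.encode()).hexdigest()


    docs = ["  doc one ", "doc one", "doc two"]

    # 1. Strip whitespace and de-duplicate, so identical content is enqueued once.
    unique_contents = list({doc.strip() for doc in docs})

    # 2. Key each document by its content hash, making re-insertion idempotent.
    new_docs = {
        compute_mdhash_id(content, prefix="doc-"): {"content": content}
        for content in unique_contents
    }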
@@ -872,11 +851,11 @@ class LightRAG:
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
 
-    def insert_custom_kg(self, custom_kg: dict):
+    def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
 
-    async def ainsert_custom_kg(self, custom_kg: dict):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
         update_storage = False
         try:
             # Insert chunks into vector storage
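Both sync wrappers (insert and insert_custom_kg) rely on the same always_get_an_event_loop trick: reuse the thread's event loop if one exists, otherwise create and register one. A minimal sketch of such a helper, written as an assumption about its behavior (the real one lives in lightrag.utils):

    import asyncio


    def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
        """Return this thread's event loop, creating and registering one if needed."""
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            # No loop in this thread (e.g. a worker thread): make one and register it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

The sync facade then simply drives the coroutine to completion with loop.run_until_complete(...), as both wrappers in the hunk above do.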