cleaning the message and project no needed
This commit is contained in:
@@ -83,11 +83,11 @@ class StorageNameSpace:
|
||||
namespace: str
|
||||
global_config: dict[str, Any]
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
"""Commit the storage operations after indexing"""
|
||||
pass
|
||||
|
||||
async def query_done_callback(self):
|
||||
async def query_done_callback(self) -> None:
|
||||
"""Commit the storage operations after querying"""
|
||||
pass
|
||||
|
||||
|
@@ -6,7 +6,7 @@ import configparser
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional, Type, Union, cast
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
from .base import (
|
||||
BaseGraphStorage,
|
||||
@@ -304,7 +304,7 @@ class LightRAG:
|
||||
- random_seed: Seed value for reproducibility.
|
||||
"""
|
||||
|
||||
embedding_func: Union[EmbeddingFunc, None] = None
|
||||
embedding_func: EmbeddingFunc | None = None
|
||||
"""Function for computing text embeddings. Must be set before use."""
|
||||
|
||||
embedding_batch_num: int = 32
|
||||
@@ -344,10 +344,8 @@ class LightRAG:
|
||||
|
||||
# Extensions
|
||||
addon_params: dict[str, Any] = field(default_factory=dict)
|
||||
"""Dictionary for additional parameters and extensions."""
|
||||
|
||||
# extension
|
||||
addon_params: dict[str, Any] = field(default_factory=dict)
|
||||
"""Dictionary for additional parameters and extensions."""
|
||||
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
|
||||
convert_response_to_json
|
||||
)
|
||||
@@ -445,77 +443,74 @@ class LightRAG:
|
||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||
|
||||
# Init LLM
|
||||
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
||||
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
||||
self.embedding_func
|
||||
)
|
||||
|
||||
# Initialize all storages
|
||||
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( # type: ignore
|
||||
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
|
||||
self._get_storage_class(self.kv_storage)
|
||||
)
|
||||
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( # type: ignore
|
||||
) # type: ignore
|
||||
self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
|
||||
self.vector_storage
|
||||
)
|
||||
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( # type: ignore
|
||||
) # type: ignore
|
||||
self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
|
||||
self.graph_storage
|
||||
)
|
||||
|
||||
self.key_string_value_json_storage_cls = partial( # type: ignore
|
||||
) # type: ignore
|
||||
self.key_string_value_json_storage_cls = partial( # type: ignore
|
||||
self.key_string_value_json_storage_cls, global_config=global_config
|
||||
)
|
||||
|
||||
self.vector_db_storage_cls = partial( # type: ignore
|
||||
self.vector_db_storage_cls = partial( # type: ignore
|
||||
self.vector_db_storage_cls, global_config=global_config
|
||||
)
|
||||
|
||||
self.graph_storage_cls = partial( # type: ignore
|
||||
self.graph_storage_cls = partial( # type: ignore
|
||||
self.graph_storage_cls, global_config=global_config
|
||||
)
|
||||
|
||||
# Initialize document status storage
|
||||
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
||||
|
||||
self.llm_response_cache = self.key_string_value_json_storage_cls( # type: ignore
|
||||
self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
|
||||
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
self.entities_vdb = self.vector_db_storage_cls( # type: ignore
|
||||
self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"entity_name"},
|
||||
)
|
||||
self.relationships_vdb = self.vector_db_storage_cls( # type: ignore
|
||||
self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"src_id", "tgt_id"},
|
||||
)
|
||||
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
|
||||
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
|
||||
),
|
||||
@@ -535,16 +530,16 @@ class LightRAG:
|
||||
):
|
||||
hashing_kv = self.llm_response_cache
|
||||
else:
|
||||
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
|
||||
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
|
||||
namespace=make_namespace(
|
||||
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
||||
),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
|
||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||
partial(
|
||||
self.llm_model_func, # type: ignore
|
||||
self.llm_model_func, # type: ignore
|
||||
hashing_kv=hashing_kv,
|
||||
**self.llm_model_kwargs,
|
||||
)
|
||||
@@ -836,32 +831,32 @@ class LightRAG:
|
||||
raise e
|
||||
|
||||
async def _insert_done(self):
|
||||
tasks = []
|
||||
for storage_inst in [
|
||||
self.full_docs,
|
||||
self.text_chunks,
|
||||
self.llm_response_cache,
|
||||
self.entities_vdb,
|
||||
self.relationships_vdb,
|
||||
self.chunks_vdb,
|
||||
self.chunk_entity_relation_graph,
|
||||
]:
|
||||
if storage_inst is None:
|
||||
continue
|
||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||
tasks = [
|
||||
cast(StorageNameSpace, storage_inst).index_done_callback()
|
||||
for storage_inst in [ # type: ignore
|
||||
self.full_docs,
|
||||
self.text_chunks,
|
||||
self.llm_response_cache,
|
||||
self.entities_vdb,
|
||||
self.relationships_vdb,
|
||||
self.chunks_vdb,
|
||||
self.chunk_entity_relation_graph,
|
||||
]
|
||||
if storage_inst is not None
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
|
||||
def insert_custom_kg(self, custom_kg: dict[str, Any]):
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
|
||||
|
||||
async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
|
||||
async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
|
||||
update_storage = False
|
||||
try:
|
||||
# Insert chunks into vector storage
|
||||
all_chunks_data = {}
|
||||
chunk_to_source_map = {}
|
||||
for chunk_data in custom_kg.get("chunks", []):
|
||||
all_chunks_data: dict[str, dict[str, str]] = {}
|
||||
chunk_to_source_map: dict[str, str] = {}
|
||||
for chunk_data in custom_kg.get("chunks", {}):
|
||||
chunk_content = chunk_data["content"]
|
||||
source_id = chunk_data["source_id"]
|
||||
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
|
||||
@@ -871,13 +866,13 @@ class LightRAG:
|
||||
chunk_to_source_map[source_id] = chunk_id
|
||||
update_storage = True
|
||||
|
||||
if self.chunks_vdb is not None and all_chunks_data:
|
||||
if all_chunks_data:
|
||||
await self.chunks_vdb.upsert(all_chunks_data)
|
||||
if self.text_chunks is not None and all_chunks_data:
|
||||
if all_chunks_data:
|
||||
await self.text_chunks.upsert(all_chunks_data)
|
||||
|
||||
# Insert entities into knowledge graph
|
||||
all_entities_data = []
|
||||
all_entities_data: list[dict[str, str]] = []
|
||||
for entity_data in custom_kg.get("entities", []):
|
||||
entity_name = f'"{entity_data["entity_name"].upper()}"'
|
||||
entity_type = entity_data.get("entity_type", "UNKNOWN")
|
||||
@@ -893,7 +888,7 @@ class LightRAG:
|
||||
)
|
||||
|
||||
# Prepare node data
|
||||
node_data = {
|
||||
node_data: dict[str, str] = {
|
||||
"entity_type": entity_type,
|
||||
"description": description,
|
||||
"source_id": source_id,
|
||||
@@ -907,7 +902,7 @@ class LightRAG:
|
||||
update_storage = True
|
||||
|
||||
# Insert relationships into knowledge graph
|
||||
all_relationships_data = []
|
||||
all_relationships_data: list[dict[str, str]] = []
|
||||
for relationship_data in custom_kg.get("relationships", []):
|
||||
src_id = f'"{relationship_data["src_id"].upper()}"'
|
||||
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
|
||||
@@ -949,7 +944,7 @@ class LightRAG:
|
||||
"source_id": source_id,
|
||||
},
|
||||
)
|
||||
edge_data = {
|
||||
edge_data: dict[str, str] = {
|
||||
"src_id": src_id,
|
||||
"tgt_id": tgt_id,
|
||||
"description": description,
|
||||
@@ -959,19 +954,17 @@ class LightRAG:
|
||||
update_storage = True
|
||||
|
||||
# Insert entities into vector storage if needed
|
||||
if self.entities_vdb is not None:
|
||||
data_for_vdb = {
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
||||
"content": dp["entity_name"] + dp["description"],
|
||||
"entity_name": dp["entity_name"],
|
||||
}
|
||||
for dp in all_entities_data
|
||||
}
|
||||
await self.entities_vdb.upsert(data_for_vdb)
|
||||
await self.entities_vdb.upsert(data_for_vdb)
|
||||
|
||||
# Insert relationships into vector storage if needed
|
||||
if self.relationships_vdb is not None:
|
||||
data_for_vdb = {
|
||||
data_for_vdb = {
|
||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||
"src_id": dp["src_id"],
|
||||
"tgt_id": dp["tgt_id"],
|
||||
@@ -982,18 +975,49 @@ class LightRAG:
|
||||
}
|
||||
for dp in all_relationships_data
|
||||
}
|
||||
await self.relationships_vdb.upsert(data_for_vdb)
|
||||
await self.relationships_vdb.upsert(data_for_vdb)
|
||||
|
||||
finally:
|
||||
if update_storage:
|
||||
await self._insert_done()
|
||||
|
||||
def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
|
||||
def query(
|
||||
self,
|
||||
query: str,
|
||||
param: QueryParam = QueryParam(),
|
||||
prompt: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Perform a sync query.
|
||||
|
||||
Args:
|
||||
query (str): The query to be executed.
|
||||
param (QueryParam): Configuration parameters for query execution.
|
||||
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
|
||||
|
||||
Returns:
|
||||
str: The result of the query execution.
|
||||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.aquery(query, prompt, param))
|
||||
return loop.run_until_complete(self.aquery(query, param, prompt))
|
||||
|
||||
async def aquery(
|
||||
self, query: str, prompt: str = "", param: QueryParam = QueryParam()
|
||||
):
|
||||
self,
|
||||
query: str,
|
||||
param: QueryParam = QueryParam(),
|
||||
prompt: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Perform a async query.
|
||||
|
||||
Args:
|
||||
query (str): The query to be executed.
|
||||
param (QueryParam): Configuration parameters for query execution.
|
||||
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
|
||||
|
||||
Returns:
|
||||
str: The result of the query execution.
|
||||
"""
|
||||
if param.mode in ["local", "global", "hybrid"]:
|
||||
response = await kg_query(
|
||||
query,
|
||||
|
@@ -295,8 +295,8 @@ async def extract_entities(
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entity_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
global_config: dict,
|
||||
llm_response_cache: BaseKVStorage = None,
|
||||
global_config: dict[str, str],
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
) -> Union[BaseGraphStorage, None]:
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||
@@ -563,15 +563,15 @@ async def extract_entities(
|
||||
|
||||
|
||||
async def kg_query(
|
||||
query,
|
||||
query: str,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entities_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
prompt: str = "",
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> str:
|
||||
# Handle cache
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
@@ -681,8 +681,8 @@ async def kg_query(
|
||||
async def extract_keywords_only(
|
||||
text: str,
|
||||
param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Extract high-level and low-level keywords from the given 'text' using the LLM.
|
||||
@@ -778,8 +778,8 @@ async def mix_kg_vector_query(
|
||||
chunks_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Hybrid retrieval implementation combining knowledge graph and vector search.
|
||||
@@ -1499,12 +1499,12 @@ def combine_contexts(entities, relationships, sources):
|
||||
|
||||
|
||||
async def naive_query(
|
||||
query,
|
||||
query: str,
|
||||
chunks_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
):
|
||||
# Handle cache
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
@@ -1606,8 +1606,8 @@ async def kg_query_with_keywords(
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict,
|
||||
hashing_kv: BaseKVStorage = None,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Refactored kg_query that does NOT extract keywords by itself.
|
||||
|
@@ -128,7 +128,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
|
||||
return hashlib.md5(args_str.encode()).hexdigest()
|
||||
|
||||
|
||||
def compute_mdhash_id(content, prefix: str = ""):
|
||||
def compute_mdhash_id(content: str, prefix: str = "") -> str:
|
||||
"""
|
||||
Compute a unique ID for a given content string.
|
||||
|
||||
The ID is a combination of the given prefix and the MD5 hash of the content string.
|
||||
"""
|
||||
return prefix + md5(content.encode()).hexdigest()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user