Merge branch 'main' into add-env-settings

yangdx
2025-02-16 22:34:39 +08:00
25 changed files with 1086 additions and 793 deletions


@@ -1,10 +1,12 @@
from __future__ import annotations
import asyncio
import os
import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, Callable, Optional, Type, Union, cast
from typing import Any, AsyncIterator, Callable, Iterator, cast
from .base import (
BaseGraphStorage,
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
@@ -91,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
}
# Environment variables required by each storage implementation (no default values)
STORAGE_ENV_REQUIREMENTS = {
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -160,6 +164,7 @@ STORAGES = {
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
@@ -176,7 +181,7 @@ STORAGES = {
}
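# Illustrative sketch (not part of this diff): with "MongoVectorDBStorage" registered in
# STORAGES above, the backend can be selected by name through the LightRAG constructor.
# The working_dir and the embedding/LLM callables below are placeholders, not values
# introduced by this change.
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",            # hypothetical path
    vector_storage="MongoVectorDBStorage",  # name registered in STORAGES above
    embedding_func=my_embedding_func,       # hypothetical EmbeddingFunc
    llm_model_func=my_llm_model_func,       # hypothetical async LLM callable
)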
def lazy_external_import(module_name: str, class_name: str):
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args, **kwargs):
def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
@@ -302,7 +307,7 @@ class LightRAG:
- random_seed: Seed value for reproducibility.
"""
embedding_func: EmbeddingFunc = None
embedding_func: EmbeddingFunc | None = None
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = 32
@@ -312,7 +317,7 @@ class LightRAG:
"""Maximum number of concurrent embedding function calls."""
# LLM Configuration
llm_model_func: callable = None
llm_model_func: Callable[..., object] | None = None
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -342,10 +347,8 @@ class LightRAG:
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
# extension
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
)
@@ -354,7 +357,7 @@ class LightRAG:
chunking_func: Callable[
[
str,
Optional[str],
str | None,
bool,
int,
int,
@@ -443,77 +446,74 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
) # type: ignore
self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
) # type: ignore
self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
)
self.key_string_value_json_storage_cls = partial(
) # type: ignore
self.key_string_value_json_storage_cls = partial( # type: ignore
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls = partial( # type: ignore
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls = partial( # type: ignore
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.llm_response_cache = self.key_string_value_json_storage_cls(
self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
self.entities_vdb = self.vector_db_storage_cls(
self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
@@ -527,13 +527,12 @@ class LightRAG:
embedding_func=None,
)
# What's this for? Is this necessary?
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
@@ -542,7 +541,7 @@ class LightRAG:
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
@@ -559,68 +558,45 @@ class LightRAG:
node_label=nodel_label, max_depth=max_depth
)
def _get_storage_class(self, storage_name: str) -> dict:
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
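# Illustrative sketch (not part of this diff): _get_storage_class resolves a name through
# STORAGES into a lazily imported class, which __init__ then binds to global_config via
# partial. The namespace below is a placeholder; `rag` is the LightRAG instance from the
# earlier sketch.
from dataclasses import asdict

storage_cls = rag._get_storage_class("MongoVectorDBStorage")  # resolved from .kg.mongo_impl
demo_vdb = storage_cls(
    namespace="demo_chunks",            # hypothetical namespace
    global_config=asdict(rag),          # mirrors how __init__ supplies global_config
    embedding_func=rag.embedding_func,
)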
def set_storage_client(self, db_client):
# Deprecated: set the correct value on the *_storage fields of LightRAG instead
# Inject the db client into the storage implementations (only tested on Oracle Database)
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,
self.doc_status,
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.key_string_value_json_storage_cls,
self.chunks_vdb,
self.relationships_vdb,
self.entities_vdb,
self.graph_storage_cls,
self.chunk_entity_relation_graph,
self.llm_response_cache,
]:
# set client
storage.db = db_client
def insert(
self,
string_or_strings: Union[str, list[str]],
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Sync Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
input: Single document string or list of document strings
split_by_character: if not None, split the string by this character; if a chunk is longer than
chunk_size, further split the sub-chunk by token size.
split_by_character_only: if True, split the string by character only; ignored when
split_by_character is None.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.ainsert(string_or_strings, split_by_character, split_by_character_only)
self.ainsert(input, split_by_character, split_by_character_only)
)
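# Illustrative usage (not part of this diff): the first parameter is now named `input`
# instead of `string_or_strings`; behaviour is otherwise unchanged. `rag` is the
# LightRAG instance from the earlier sketch.
rag.insert("A single document as one string.")
rag.insert(
    ["First document ...", "Second document ..."],
    split_by_character="\n\n",        # split on blank lines first
    split_by_character_only=False,    # oversized chunks still fall back to token splitting
)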
async def ainsert(
self,
string_or_strings: Union[str, list[str]],
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Async Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
input: Single document string or list of document strings
split_by_character: if not None, split the string by this character; if a chunk is longer than
chunk_size, further split the sub-chunk by token size.
split_by_character_only: if True, split the string by character only; ignored when
split_by_character is None.
"""
await self.apipeline_enqueue_documents(string_or_strings)
await self.apipeline_enqueue_documents(input)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
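# Illustrative async usage (not part of this diff): as the two awaited calls above show,
# ainsert is equivalent to enqueuing the documents and then processing the queue.
import asyncio

async def index_docs(docs: list[str]) -> None:
    await rag.apipeline_enqueue_documents(docs)
    await rag.apipeline_process_enqueue_documents(None, False)  # split_by_character, split_by_character_only

asyncio.run(index_docs(["First document ...", "Second document ..."]))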
@@ -677,7 +653,7 @@ class LightRAG:
if update_storage:
await self._insert_done()
async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
async def apipeline_enqueue_documents(self, input: str | list[str]):
"""
Pipeline for Processing Documents
@@ -686,11 +662,11 @@ class LightRAG:
3. Filter out already processed documents
4. Enqueue document in status
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
if isinstance(input, str):
input = [input]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
unique_contents = list(set(doc.strip() for doc in input))
# 2. Generate document IDs and initial status
new_docs: dict[str, Any] = {
@@ -857,32 +833,32 @@ class LightRAG:
raise e
async def _insert_done(self):
tasks = []
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
tasks = [
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]
if storage_inst is not None
]
await asyncio.gather(*tasks)
def insert_custom_kg(self, custom_kg: dict):
def insert_custom_kg(self, custom_kg: dict[str, Any]):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
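# Illustrative payload (not part of this diff) using only the keys that ainsert_custom_kg
# reads below; all values are placeholders.
custom_kg = {
    "chunks": [
        {"content": "Alice founded Acme in 2001.", "source_id": "doc-1"},
    ],
    "entities": [
        {
            "entity_name": "Alice",
            "entity_type": "PERSON",
            "description": "Founder of Acme",
            "source_id": "doc-1",
        },
    ],
    "relationships": [
        {
            "src_id": "Alice",
            "tgt_id": "Acme",
            "description": "Alice founded Acme",
            "keywords": "founder",
            "source_id": "doc-1",
        },
    ],
}
rag.insert_custom_kg(custom_kg)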
async def ainsert_custom_kg(self, custom_kg: dict):
async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
update_storage = False
try:
# Insert chunks into vector storage
all_chunks_data = {}
chunk_to_source_map = {}
for chunk_data in custom_kg.get("chunks", []):
all_chunks_data: dict[str, dict[str, str]] = {}
chunk_to_source_map: dict[str, str] = {}
for chunk_data in custom_kg.get("chunks", {}):
chunk_content = chunk_data["content"]
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -892,13 +868,13 @@ class LightRAG:
chunk_to_source_map[source_id] = chunk_id
update_storage = True
if self.chunks_vdb is not None and all_chunks_data:
if all_chunks_data:
await self.chunks_vdb.upsert(all_chunks_data)
if self.text_chunks is not None and all_chunks_data:
if all_chunks_data:
await self.text_chunks.upsert(all_chunks_data)
# Insert entities into knowledge graph
all_entities_data = []
all_entities_data: list[dict[str, str]] = []
for entity_data in custom_kg.get("entities", []):
entity_name = f'"{entity_data["entity_name"].upper()}"'
entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -914,7 +890,7 @@ class LightRAG:
)
# Prepare node data
node_data = {
node_data: dict[str, str] = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
@@ -928,7 +904,7 @@ class LightRAG:
update_storage = True
# Insert relationships into knowledge graph
all_relationships_data = []
all_relationships_data: list[dict[str, str]] = []
for relationship_data in custom_kg.get("relationships", []):
src_id = f'"{relationship_data["src_id"].upper()}"'
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -970,7 +946,7 @@ class LightRAG:
"source_id": source_id,
},
)
edge_data = {
edge_data: dict[str, str] = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
@@ -980,41 +956,68 @@ class LightRAG:
update_storage = True
# Insert entities into vector storage if needed
if self.entities_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
await self.entities_vdb.upsert(data_for_vdb)
for dp in all_entities_data
}
await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
if self.relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
await self.relationships_vdb.upsert(data_for_vdb)
for dp in all_relationships_data
}
await self.relationships_vdb.upsert(data_for_vdb)
finally:
if update_storage:
await self._insert_done()
def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
def query(
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
) -> str | Iterator[str]:
"""
Perform a sync query.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
prompt (str | None): Custom prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
str: The result of the query execution.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, prompt, param))
return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
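# Illustrative usage (not part of this diff): `param` now comes before the optional
# `prompt`, which defaults to None and falls back to PROMPTS["rag_response"].
from lightrag import QueryParam

answer = rag.query("What are the main themes?", param=QueryParam(mode="hybrid"))
custom = rag.query(
    "Summarize the key entities.",
    param=QueryParam(mode="local"),
    prompt="Answer in bullet points.",   # hypothetical custom prompt
)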
async def aquery(
self, query: str, prompt: str = "", param: QueryParam = QueryParam()
):
self,
query: str,
param: QueryParam = QueryParam(),
prompt: str | None = None,
) -> str | AsyncIterator[str]:
"""
Perform an async query.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
prompt (str | None): Custom prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
str: The result of the query execution.
"""
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
@@ -1094,7 +1097,7 @@ class LightRAG:
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
) -> str | AsyncIterator[str]:
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...) as the main query, injecting the newly extracted keywords where needed.
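# Illustrative async usage (not part of this diff, to be run inside an async context):
# keywords are extracted first, then the normal kg/naive query path runs with them injected.
response = await rag.aquery_with_separate_keyword_extraction(
    "How are Alice and Acme related?",
    prompt="Answer concisely.",           # hypothetical custom prompt
    param=QueryParam(mode="hybrid"),
)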
@@ -1117,8 +1120,8 @@ class LightRAG:
),
)
param.hl_keywords = (hl_keywords,)
param.ll_keywords = (ll_keywords,)
param.hl_keywords = hl_keywords
param.ll_keywords = ll_keywords
# ---------------------
# STEP 2: Final Query Logic
@@ -1146,7 +1149,7 @@ class LightRAG:
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_funcne,
embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
@@ -1195,12 +1198,7 @@ class LightRAG:
return response
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
await self.llm_response_cache.index_done_callback()
def delete_by_entity(self, entity_name: str):
loop = always_get_an_event_loop()
@@ -1222,16 +1220,16 @@ class LightRAG:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
tasks = []
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]
]
)
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
@@ -1256,7 +1254,7 @@ class LightRAG:
"""
return await self.doc_status.get_status_counts()
async def adelete_by_doc_id(self, doc_id: str):
async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data
Args:
@@ -1273,6 +1271,9 @@ class LightRAG:
# 2. Get all related chunks
chunks = await self.text_chunks.get_by_id(doc_id)
if not chunks:
return
chunk_ids = list(chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
@@ -1443,13 +1444,9 @@ class LightRAG:
except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}")
def delete_by_doc_id(self, doc_id: str):
"""Synchronous version of adelete"""
return asyncio.run(self.adelete_by_doc_id(doc_id))
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity
Args:
@@ -1469,7 +1466,7 @@ class LightRAG:
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": node_data,
@@ -1483,21 +1480,6 @@ class LightRAG:
return result
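# Illustrative result shape (not part of this diff, inside an async context): keys mirror
# the dict built above; "vector_data" is only added when include_vector_data=True.
info = await rag.get_entity_info("Alice", include_vector_data=True)
# info -> {"entity_name": ..., "source_id": ..., "graph_data": ..., "vector_data": ...}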
def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
"""Synchronous version of getting entity information
Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
finally:
tracemalloc.stop()
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
@@ -1525,7 +1507,7 @@ class LightRAG:
)
source_id = edge_data.get("source_id") if edge_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"src_entity": src_entity,
"tgt_entity": tgt_entity,
"source_id": source_id,
@@ -1539,23 +1521,3 @@ class LightRAG:
result["vector_data"] = vector_data[0] if vector_data else None
return result
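# Illustrative result shape (not part of this diff, inside an async context): analogous to
# get_entity_info, keyed by the source/target entity pair.
rel = await rag.get_relation_info("Alice", "Acme", include_vector_data=True)
# rel -> {"src_entity": ..., "tgt_entity": ..., "source_id": ..., "graph_data": ..., "vector_data": ...}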
def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information
Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(
self.get_relation_info(src_entity, tgt_entity, include_vector_data)
)
finally:
tracemalloc.stop()