Merge branch 'main' into add-env-settings
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import asyncio
 import os
 import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast

 from .base import (
     BaseGraphStorage,
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
             "FaissVectorDBStorage",
             "QdrantVectorDBStorage",
             "OracleVectorDBStorage",
+            "MongoVectorDBStorage",
         ],
         "required_methods": ["query", "upsert"],
     },
@@ -91,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
 }

 # Storage implementation environment variables without default value
-STORAGE_ENV_REQUIREMENTS = {
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     # KV Storage Implementations
     "JsonKVStorage": [],
     "MongoKVStorage": [],
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
         "ORACLE_PASSWORD",
         "ORACLE_CONFIG_DIR",
     ],
+    "MongoVectorDBStorage": [],
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
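STORAGE_ENV_REQUIREMENTS drives startup validation: a backend should only be instantiated once every variable listed for it is present in the environment. A minimal sketch of such a check (`verify_storage_env` is a hypothetical helper, not part of this diff):

    import os

    def verify_storage_env(storage_name: str) -> None:
        # Unknown backends fall back to an empty requirement list.
        required = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
        missing = [var for var in required if var not in os.environ]
        if missing:
            raise ValueError(
                f"{storage_name} requires environment variables: {', '.join(missing)}"
            )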
@@ -160,6 +164,7 @@ STORAGES = {
     "MongoKVStorage": ".kg.mongo_impl",
     "MongoDocStatusStorage": ".kg.mongo_impl",
     "MongoGraphStorage": ".kg.mongo_impl",
+    "MongoVectorDBStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
     "TiDBKVStorage": ".kg.tidb_impl",
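Adding a backend means touching three registries: the capability listing in STORAGE_IMPLEMENTATIONS, the required variables in STORAGE_ENV_REQUIREMENTS, and the import path in STORAGES — hence "MongoVectorDBStorage" appearing in three consecutive hunks. A hypothetical registration, assuming the vector group is keyed as "VECTOR_STORAGE" (all names below are illustrative, not part of this diff):

    STORAGE_IMPLEMENTATIONS["VECTOR_STORAGE"]["implementations"].append("MyVectorDBStorage")
    STORAGE_ENV_REQUIREMENTS["MyVectorDBStorage"] = ["MY_DB_URI"]
    STORAGES["MyVectorDBStorage"] = ".kg.my_impl"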
@@ -176,7 +181,7 @@ STORAGES = {
 }


-def lazy_external_import(module_name: str, class_name: str):
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     """Lazily import a class from an external module based on the package of the caller."""
     # Get the caller's module and package
     import inspect
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
     module = inspect.getmodule(caller_frame)
     package = module.__package__ if module else None

-    def import_class(*args, **kwargs):
+    def import_class(*args: Any, **kwargs: Any):
         import importlib

         module = importlib.import_module(module_name, package=package)
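The factory returned by lazy_external_import defers the actual import until the storage class is first constructed, so optional backend dependencies need not be installed. A condensed sketch of the same mechanics:

    from __future__ import annotations

    import importlib
    from typing import Any, Callable

    def lazy_class(module_name: str, class_name: str, package: str | None = None) -> Callable[..., Any]:
        def construct(*args: Any, **kwargs: Any) -> Any:
            # The module is imported only when the factory is first called.
            module = importlib.import_module(module_name, package=package)
            cls = getattr(module, class_name)
            return cls(*args, **kwargs)

        return construct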
@@ -302,7 +307,7 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """

-    embedding_func: EmbeddingFunc = None
+    embedding_func: EmbeddingFunc | None = None
     """Function for computing text embeddings. Must be set before use."""

     embedding_batch_num: int = 32
@@ -312,7 +317,7 @@ class LightRAG:
     """Maximum number of concurrent embedding function calls."""

     # LLM Configuration
-    llm_model_func: callable = None
+    llm_model_func: Callable[..., object] | None = None
     """Function for interacting with the large language model (LLM). Must be set before use."""

     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
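The `EmbeddingFunc | None` and `Callable[..., object] | None` spellings lean on the `from __future__ import annotations` added in the first hunk: with it, annotations are kept as strings, so PEP 604 unions parse even on interpreters without native `X | Y` union support. A minimal illustration:

    from __future__ import annotations

    from dataclasses import dataclass
    from typing import Callable

    @dataclass
    class DemoConfig:
        # Stored as the string "Callable[..., object] | None" and never evaluated,
        # so this parses even where `X | Y` union objects do not exist.
        llm_model_func: Callable[..., object] | None = None

    assert DemoConfig().llm_model_func is None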
@@ -342,10 +347,8 @@ class LightRAG:

     # Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)
     """Dictionary for additional parameters and extensions."""
-    # extension
-    addon_params: dict[str, Any] = field(default_factory=dict)
-    """Dictionary for additional parameters and extensions."""
+
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
     )
@@ -354,7 +357,7 @@ class LightRAG:
     chunking_func: Callable[
         [
             str,
-            Optional[str],
+            str | None,
             bool,
             int,
             int,
@@ -443,77 +446,74 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         # Init LLM
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
+        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
             self.embedding_func
         )

         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
             self._get_storage_class(self.kv_storage)
-        )
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+        )  # type: ignore
+        self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
             self.vector_storage
-        )
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+        )  # type: ignore
+        self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
             self.graph_storage
-        )
-
-        self.key_string_value_json_storage_cls = partial(
+        )  # type: ignore
+        self.key_string_value_json_storage_cls = partial(  # type: ignore
             self.key_string_value_json_storage_cls, global_config=global_config
         )
-
-        self.vector_db_storage_cls = partial(
+        self.vector_db_storage_cls = partial(  # type: ignore
             self.vector_db_storage_cls, global_config=global_config
         )
-
-        self.graph_storage_cls = partial(
+        self.graph_storage_cls = partial(  # type: ignore
             self.graph_storage_cls, global_config=global_config
         )

         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

-        self.llm_response_cache = self.key_string_value_json_storage_cls(
+        self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
             embedding_func=self.embedding_func,
         )

-        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
-        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
-        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
+        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
             embedding_func=self.embedding_func,
         )

-        self.entities_vdb = self.vector_db_storage_cls(
+        self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
-        self.relationships_vdb = self.vector_db_storage_cls(
+        self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
-        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
+        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
             ),
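The functools.partial rebinding above pre-applies the shared global_config, so every later call site constructs a storage with only the arguments that vary. A reduced sketch of the pattern (the storage class here is a stand-in):

    from functools import partial

    class DemoKVStorage:
        def __init__(self, namespace: str, global_config: dict):
            self.namespace = namespace
            self.global_config = global_config

    # Bind the shared config once...
    kv_factory = partial(DemoKVStorage, global_config={"working_dir": "./rag_storage"})
    # ...then each call site supplies only what varies.
    cache = kv_factory(namespace="llm_response_cache")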
@@ -527,13 +527,12 @@ class LightRAG:
             embedding_func=None,
         )

-        # What's for, Is this nessisary ?
         if self.llm_response_cache and hasattr(
             self.llm_response_cache, "global_config"
         ):
             hashing_kv = self.llm_response_cache
         else:
-            hashing_kv = self.key_string_value_json_storage_cls(
+            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                 namespace=make_namespace(
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
@@ -542,7 +541,7 @@ class LightRAG:

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
-                self.llm_model_func,
+                self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
                 **self.llm_model_kwargs,
             )
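limit_async_func_call caps how many invocations of the wrapped coroutine are in flight at once (embedding_func_max_async and llm_model_max_async above). A plausible semaphore-based sketch; the project's real helper lives in its utils module and may be implemented differently:

    import asyncio
    from functools import wraps

    def limit_async_func_call(max_size: int):
        def decorator(func):
            sem = asyncio.Semaphore(max_size)

            @wraps(func)
            async def wrapped(*args, **kwargs):
                # At most `max_size` calls execute inside this block at a time.
                async with sem:
                    return await func(*args, **kwargs)

            return wrapped

        return decorator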
@@ -559,68 +558,45 @@ class LightRAG:
             node_label=nodel_label, max_depth=max_depth
         )

-    def _get_storage_class(self, storage_name: str) -> dict:
+    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class

-    def set_storage_client(self, db_client):
-        # Deprecated, seting correct value to *_storage of LightRAG insteaded
-        # Inject db to storage implementation (only tested on Oracle Database)
-        for storage in [
-            self.vector_db_storage_cls,
-            self.graph_storage_cls,
-            self.doc_status,
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.key_string_value_json_storage_cls,
-            self.chunks_vdb,
-            self.relationships_vdb,
-            self.entities_vdb,
-            self.graph_storage_cls,
-            self.chunk_entity_relation_graph,
-            self.llm_response_cache,
-        ]:
-            # set client
-            storage.db = db_client
-
     def insert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Sync Insert documents with checkpoint support

         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
-            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only)
         )

     async def ainsert(
         self,
-        string_or_strings: Union[str, list[str]],
+        input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
     ):
         """Async Insert documents with checkpoint support

         Args:
-            string_or_strings: Single document string or list of document strings
+            input: Single document string or list of document strings
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             chunk_size, split the sub chunk by token size.
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
-        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_enqueue_documents(input)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
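insert stays a thin synchronous facade over ainsert: it obtains an event loop via always_get_an_event_loop() and drives the coroutine to completion, so the string_or_strings → input rename only changes one public surface. Typical usage, assuming a fully configured instance (a real one also needs embedding_func and llm_model_func set):

    rag = LightRAG(working_dir="./rag_storage")  # illustrative construction

    rag.insert("First document text ...")
    rag.insert(
        ["First document text ...", "Second document text ..."],
        split_by_character="\n\n",
    )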
@@ -677,7 +653,7 @@ class LightRAG:
         if update_storage:
             await self._insert_done()

-    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
+    async def apipeline_enqueue_documents(self, input: str | list[str]):
         """
         Pipeline for Processing Documents

@@ -686,11 +662,11 @@ class LightRAG:
         3. Filter out already processed documents
         4. Enqueue document in status
         """
-        if isinstance(string_or_strings, str):
-            string_or_strings = [string_or_strings]
+        if isinstance(input, str):
+            input = [input]

         # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+        unique_contents = list(set(doc.strip() for doc in input))

         # 2. Generate document IDs and initial status
         new_docs: dict[str, Any] = {
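Deduplication is by exact stripped content, and the step-2 IDs come from compute_mdhash_id, i.e. they are derived from the content itself, so re-enqueuing an identical document produces the same key. A sketch consistent with the helper's name (MD5 assumed):

    from hashlib import md5

    def compute_mdhash_id(content: str, prefix: str = "") -> str:
        # Content-derived key: identical text always maps to the same ID.
        return prefix + md5(content.encode()).hexdigest()

    unique_contents = list({doc.strip() for doc in ["a doc", "a doc ", "another doc"]})
    new_docs = {compute_mdhash_id(c, prefix="doc-"): {"content": c} for c in unique_contents}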
@@ -857,32 +833,32 @@ class LightRAG:
             raise e

     async def _insert_done(self):
-        tasks = []
-        for storage_inst in [
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.chunks_vdb,
-            self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
+        tasks = [
+            cast(StorageNameSpace, storage_inst).index_done_callback()
+            for storage_inst in [  # type: ignore
+                self.full_docs,
+                self.text_chunks,
+                self.llm_response_cache,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+            ]
+            if storage_inst is not None
+        ]
         await asyncio.gather(*tasks)

-    def insert_custom_kg(self, custom_kg: dict):
+    def insert_custom_kg(self, custom_kg: dict[str, Any]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

-    async def ainsert_custom_kg(self, custom_kg: dict):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
         update_storage = False
         try:
             # Insert chunks into vector storage
-            all_chunks_data = {}
-            chunk_to_source_map = {}
-            for chunk_data in custom_kg.get("chunks", []):
+            all_chunks_data: dict[str, dict[str, str]] = {}
+            chunk_to_source_map: dict[str, str] = {}
+            for chunk_data in custom_kg.get("chunks", {}):
                 chunk_content = chunk_data["content"]
                 source_id = chunk_data["source_id"]
                 chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
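The _insert_done rewrite is behavior-preserving: both versions collect index_done_callback() coroutines from every non-None storage and await them concurrently. The same idea as a standalone function:

    import asyncio

    async def flush_all(storages) -> None:
        # Skip unset storages, then run every flush callback concurrently.
        await asyncio.gather(
            *(s.index_done_callback() for s in storages if s is not None)
        )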
@@ -892,13 +868,13 @@ class LightRAG:
                 chunk_to_source_map[source_id] = chunk_id
                 update_storage = True

-            if self.chunks_vdb is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.chunks_vdb.upsert(all_chunks_data)
-            if self.text_chunks is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.text_chunks.upsert(all_chunks_data)

             # Insert entities into knowledge graph
-            all_entities_data = []
+            all_entities_data: list[dict[str, str]] = []
             for entity_data in custom_kg.get("entities", []):
                 entity_name = f'"{entity_data["entity_name"].upper()}"'
                 entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -914,7 +890,7 @@ class LightRAG:
                 )

                 # Prepare node data
-                node_data = {
+                node_data: dict[str, str] = {
                     "entity_type": entity_type,
                     "description": description,
                     "source_id": source_id,
@@ -928,7 +904,7 @@ class LightRAG:
                 update_storage = True

             # Insert relationships into knowledge graph
-            all_relationships_data = []
+            all_relationships_data: list[dict[str, str]] = []
             for relationship_data in custom_kg.get("relationships", []):
                 src_id = f'"{relationship_data["src_id"].upper()}"'
                 tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -970,7 +946,7 @@ class LightRAG:
                         "source_id": source_id,
                     },
                 )
-                edge_data = {
+                edge_data: dict[str, str] = {
                     "src_id": src_id,
                     "tgt_id": tgt_id,
                     "description": description,
@@ -980,41 +956,68 @@ class LightRAG:
                 update_storage = True

             # Insert entities into vector storage if needed
-            if self.entities_vdb is not None:
-                data_for_vdb = {
-                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                        "content": dp["entity_name"] + dp["description"],
-                        "entity_name": dp["entity_name"],
-                    }
-                    for dp in all_entities_data
-                }
-                await self.entities_vdb.upsert(data_for_vdb)
+            data_for_vdb = {
+                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                    "content": dp["entity_name"] + dp["description"],
+                    "entity_name": dp["entity_name"],
+                }
+                for dp in all_entities_data
+            }
+            await self.entities_vdb.upsert(data_for_vdb)

             # Insert relationships into vector storage if needed
-            if self.relationships_vdb is not None:
-                data_for_vdb = {
-                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                        "src_id": dp["src_id"],
-                        "tgt_id": dp["tgt_id"],
-                        "content": dp["keywords"]
-                        + dp["src_id"]
-                        + dp["tgt_id"]
-                        + dp["description"],
-                    }
-                    for dp in all_relationships_data
-                }
-                await self.relationships_vdb.upsert(data_for_vdb)
+            data_for_vdb = {
+                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                    "src_id": dp["src_id"],
+                    "tgt_id": dp["tgt_id"],
+                    "content": dp["keywords"]
+                    + dp["src_id"]
+                    + dp["tgt_id"]
+                    + dp["description"],
+                }
+                for dp in all_relationships_data
+            }
+            await self.relationships_vdb.upsert(data_for_vdb)

         finally:
             if update_storage:
                 await self._insert_done()

-    def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
+    def query(
+        self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
+    ) -> str | Iterator[str]:
+        """
+        Perform a sync query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.aquery(query, prompt, param))
+        return loop.run_until_complete(self.aquery(query, param, prompt))  # type: ignore

     async def aquery(
-        self, query: str, prompt: str = "", param: QueryParam = QueryParam()
-    ):
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        prompt: str | None = None,
+    ) -> str | AsyncIterator[str]:
+        """
+        Perform a async query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query,
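Note that the query/aquery reorder is breaking for callers that passed prompt positionally: the second positional slot now carries QueryParam, and a custom prompt moves to the trailing parameter. Call sites after this change look like:

    param = QueryParam(mode="hybrid")

    answer = rag.query("How are entity X and entity Y related?", param)
    answer = rag.query("How are entity X and entity Y related?", param, prompt="Answer tersely.")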
@@ -1094,7 +1097,7 @@ class LightRAG:

     async def aquery_with_separate_keyword_extraction(
         self, query: str, prompt: str, param: QueryParam = QueryParam()
-    ):
+    ) -> str | AsyncIterator[str]:
         """
         1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
         2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1117,8 +1120,8 @@ class LightRAG:
             ),
         )

-        param.hl_keywords = (hl_keywords,)
-        param.ll_keywords = (ll_keywords,)
+        param.hl_keywords = hl_keywords
+        param.ll_keywords = ll_keywords

         # ---------------------
         # STEP 2: Final Query Logic
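The keyword assignments above fix a real bug: the old trailing-comma form wrapped each keyword list in a one-element tuple, so downstream code saw (["kw", ...],) rather than the list itself. A tiny demonstration:

    hl_keywords = ["acquisition", "merger"]

    wrapped = (hl_keywords,)  # old behavior: a 1-tuple containing the list
    plain = hl_keywords       # new behavior: the list itself

    assert wrapped != plain
    assert wrapped[0] is plain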
@@ -1146,7 +1149,7 @@ class LightRAG:
                         self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                     ),
                     global_config=asdict(self),
-                    embedding_func=self.embedding_funcne,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "naive":
@@ -1195,12 +1198,7 @@ class LightRAG:
         return response

     async def _query_done(self):
-        tasks = []
-        for storage_inst in [self.llm_response_cache]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
-        await asyncio.gather(*tasks)
+        await self.llm_response_cache.index_done_callback()

     def delete_by_entity(self, entity_name: str):
         loop = always_get_an_event_loop()
@@ -1222,16 +1220,16 @@ class LightRAG:
             logger.error(f"Error while deleting entity '{entity_name}': {e}")

     async def _delete_by_entity_done(self):
-        tasks = []
-        for storage_inst in [
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
-        await asyncio.gather(*tasks)
+        await asyncio.gather(
+            *[
+                cast(StorageNameSpace, storage_inst).index_done_callback()
+                for storage_inst in [  # type: ignore
+                    self.entities_vdb,
+                    self.relationships_vdb,
+                    self.chunk_entity_relation_graph,
+                ]
+            ]
+        )

     def _get_content_summary(self, content: str, max_length: int = 100) -> str:
         """Get summary of document content
@@ -1256,7 +1254,7 @@ class LightRAG:
         """
         return await self.doc_status.get_status_counts()

-    async def adelete_by_doc_id(self, doc_id: str):
+    async def adelete_by_doc_id(self, doc_id: str) -> None:
         """Delete a document and all its related data

         Args:
@@ -1273,6 +1271,9 @@ class LightRAG:

             # 2. Get all related chunks
             chunks = await self.text_chunks.get_by_id(doc_id)
+            if not chunks:
+                return
+
             chunk_ids = list(chunks.keys())
             logger.debug(f"Found {len(chunk_ids)} chunks to delete")

@@ -1443,13 +1444,9 @@ class LightRAG:
         except Exception as e:
             logger.error(f"Error while deleting document {doc_id}: {e}")

-    def delete_by_doc_id(self, doc_id: str):
-        """Synchronous version of adelete"""
-        return asyncio.run(self.adelete_by_doc_id(doc_id))
-
     async def get_entity_info(
         self, entity_name: str, include_vector_data: bool = False
-    ):
+    ) -> dict[str, str | None | dict[str, str]]:
         """Get detailed information of an entity

         Args:
@@ -1469,7 +1466,7 @@ class LightRAG:
         node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
         source_id = node_data.get("source_id") if node_data else None

-        result = {
+        result: dict[str, str | None | dict[str, str]] = {
             "entity_name": entity_name,
             "source_id": source_id,
             "graph_data": node_data,
@@ -1483,21 +1480,6 @@ class LightRAG:

         return result

-    def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
-        """Synchronous version of getting entity information
-
-        Args:
-            entity_name: Entity name (no need for quotes)
-            include_vector_data: Whether to include data from the vector database
-        """
-        try:
-            import tracemalloc
-
-            tracemalloc.start()
-            return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
-        finally:
-            tracemalloc.stop()
-
     async def get_relation_info(
         self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
     ):
@@ -1525,7 +1507,7 @@ class LightRAG:
         )
         source_id = edge_data.get("source_id") if edge_data else None

-        result = {
+        result: dict[str, str | None | dict[str, str]] = {
             "src_entity": src_entity,
             "tgt_entity": tgt_entity,
             "source_id": source_id,
@@ -1539,23 +1521,3 @@ class LightRAG:
             result["vector_data"] = vector_data[0] if vector_data else None

         return result
-
-    def get_relation_info_sync(
-        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
-    ):
-        """Synchronous version of getting relationship information
-
-        Args:
-            src_entity: Source entity name (no need for quotes)
-            tgt_entity: Target entity name (no need for quotes)
-            include_vector_data: Whether to include data from the vector database
-        """
-        try:
-            import tracemalloc
-
-            tracemalloc.start()
-            return asyncio.run(
-                self.get_relation_info(src_entity, tgt_entity, include_vector_data)
-            )
-        finally:
-            tracemalloc.stop()