Clean up messages and remove unneeded project code

This commit is contained in:
Yannick Stephan
2025-02-14 23:31:27 +01:00
parent 28c8443ff2
commit 66f555677a
7 changed files with 129 additions and 441 deletions


@@ -428,9 +428,9 @@ And using a routine to process news documents.
 ```python
 rag = LightRAG(..)
-await rag.apipeline_enqueue_documents(string_or_strings)
+await rag.apipeline_enqueue_documents(input)
 # Your routine in loop
-await rag.apipeline_process_enqueue_documents(string_or_strings)
+await rag.apipeline_process_enqueue_documents(input)
 ```
 ### Separate Keyword Extraction
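
For orientation, here is a minimal sketch of the renamed two-step pipeline shown in the hunk above; `fetch_news_batch()` is a hypothetical stand-in for a real document source, and the `LightRAG(...)` setup is assumed from earlier in this README:

```python
import asyncio
from lightrag import LightRAG

def fetch_news_batch() -> list[str]:
    # Hypothetical stand-in for a real news feed.
    return ["Example article one.", "Example article two."]

async def ingest(rag: LightRAG) -> None:
    docs = fetch_news_batch()
    # Step 1: enqueue the documents.
    await rag.apipeline_enqueue_documents(docs)
    # Step 2: process the queue from your own routine/loop, as above.
    await rag.apipeline_process_enqueue_documents(docs)
```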


@@ -113,7 +113,24 @@ async def main():
     )
     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
-    rag.set_storage_client(db_client=oracle_db)
+    for storage in [
+        rag.vector_db_storage_cls,
+        rag.graph_storage_cls,
+        rag.doc_status,
+        rag.full_docs,
+        rag.text_chunks,
+        rag.llm_response_cache,
+        rag.key_string_value_json_storage_cls,
+        rag.chunks_vdb,
+        rag.relationships_vdb,
+        rag.entities_vdb,
+        rag.graph_storage_cls,
+        rag.chunk_entity_relation_graph,
+        rag.llm_response_cache,
+    ]:
+        # set client
+        storage.db = oracle_db
     # Extract and Insert into LightRAG storage
     with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
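
Note that the list added above names `rag.graph_storage_cls` and `rag.llm_response_cache` twice; iterating over a deduplicated tuple performs the same wiring exactly once per storage. A sketch of that variant, assuming the same connected `oracle_db` client and the same attributes as above:

```python
storages = (
    rag.key_string_value_json_storage_cls,
    rag.vector_db_storage_cls,
    rag.graph_storage_cls,
    rag.doc_status,
    rag.full_docs,
    rag.text_chunks,
    rag.llm_response_cache,
    rag.chunks_vdb,
    rag.relationships_vdb,
    rag.entities_vdb,
    rag.chunk_entity_relation_graph,
)
for storage in storages:
    # Share one Oracle connection pool across every storage backend.
    storage.db = oracle_db
```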


@@ -1,358 +0,0 @@
"""
OpenWebui Lightrag Integration Tool
==================================
This tool enables the integration and use of Lightrag within the OpenWebui environment,
providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
Author: ParisNeo (parisneoai@gmail.com)
Social:
- Twitter: @ParisNeo_AI
- Reddit: r/lollms
- Instagram: https://www.instagram.com/parisneo_ai/
License: Apache 2.0
Copyright (c) 2024-2025 ParisNeo
This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
For more information, visit: https://github.com/ParisNeo/lollms
Requirements:
- Python 3.8+
- OpenWebui
- Lightrag
"""
# Tool version
__version__ = "1.0.0"
__author__ = "ParisNeo"
__author_email__ = "parisneoai@gmail.com"
__description__ = "Lightrag integration for OpenWebui"
import requests
import json
from pydantic import BaseModel, Field
from typing import Callable, Any, Literal, Union, List, Tuple
class StatusEventEmitter:
    def __init__(self, event_emitter: Callable[[dict], Any] = None):
        self.event_emitter = event_emitter

    async def emit(self, description="Unknown State", status="in_progress", done=False):
        if self.event_emitter:
            await self.event_emitter(
                {
                    "type": "status",
                    "data": {
                        "status": status,
                        "description": description,
                        "done": done,
                    },
                }
            )


class MessageEventEmitter:
    def __init__(self, event_emitter: Callable[[dict], Any] = None):
        self.event_emitter = event_emitter

    async def emit(self, content="Some message"):
        if self.event_emitter:
            await self.event_emitter(
                {
                    "type": "message",
                    "data": {
                        "content": content,
                    },
                }
            )
class Tools:
    class Valves(BaseModel):
        LIGHTRAG_SERVER_URL: str = Field(
            default="http://localhost:9621/query",
            description="The base URL for the LightRag server",
        )
        MODE: Literal["naive", "local", "global", "hybrid"] = Field(
            default="hybrid",
            description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
        )
        ONLY_NEED_CONTEXT: bool = Field(
            default=False,
            description="If True, only the context is needed from the LightRag response",
        )
        DEBUG_MODE: bool = Field(
            default=False,
            description="If True, debugging information will be emitted",
        )
        KEY: str = Field(
            default="",
            description="Optional Bearer Key for authentication",
        )
        MAX_ENTITIES: int = Field(
            default=5,
            description="Maximum number of entities to keep",
        )
        MAX_RELATIONSHIPS: int = Field(
            default=5,
            description="Maximum number of relationships to keep",
        )
        MAX_SOURCES: int = Field(
            default=3,
            description="Maximum number of sources to keep",
        )

    def __init__(self):
        self.valves = self.Valves()
        self.headers = {
            "Content-Type": "application/json",
            "User-Agent": "LightRag-Tool/1.0",
        }
    async def query_lightrag(
        self,
        query: str,
        __event_emitter__: Callable[[dict], Any] = None,
    ) -> str:
        """
        Query the LightRag server and retrieve information.
        This function must be called before answering the user question.
        :params query: The query string to send to the LightRag server.
        :return: The response from the LightRag server in Markdown format or raw response.
        """
        self.status_emitter = StatusEventEmitter(__event_emitter__)
        self.message_emitter = MessageEventEmitter(__event_emitter__)

        lightrag_url = self.valves.LIGHTRAG_SERVER_URL
        payload = {
            "query": query,
            "mode": str(self.valves.MODE),
            "stream": False,
            "only_need_context": self.valves.ONLY_NEED_CONTEXT,
        }
        await self.status_emitter.emit("Initializing Lightrag query..")

        if self.valves.DEBUG_MODE:
            await self.message_emitter.emit(
                "### Debug Mode Active\n\nDebugging information will be displayed.\n"
            )
            await self.message_emitter.emit(
                "#### Payload Sent to LightRag Server\n```json\n"
                + json.dumps(payload, indent=4)
                + "\n```\n"
            )

        # Add Bearer Key to headers if provided
        if self.valves.KEY:
            self.headers["Authorization"] = f"Bearer {self.valves.KEY}"

        try:
            await self.status_emitter.emit("Sending request to LightRag server")
            response = requests.post(
                lightrag_url, json=payload, headers=self.headers, timeout=120
            )
            response.raise_for_status()
            data = response.json()
            await self.status_emitter.emit(
                status="complete",
                description="LightRag query Succeeded",
                done=True,
            )

            # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
            if self.valves.ONLY_NEED_CONTEXT:
                try:
                    if self.valves.DEBUG_MODE:
                        await self.message_emitter.emit(
                            "#### LightRag Server Response\n```json\n"
                            + data["response"]
                            + "\n```\n"
                        )
                except Exception as ex:
                    if self.valves.DEBUG_MODE:
                        await self.message_emitter.emit(
                            "#### Exception\n" + str(ex) + "\n"
                        )
                    return f"Exception: {ex}"
                return data["response"]
            else:
                if self.valves.DEBUG_MODE:
                    await self.message_emitter.emit(
                        "#### LightRag Server Response\n```json\n"
                        + data["response"]
                        + "\n```\n"
                    )
                await self.status_emitter.emit("Lightrag query success")
                return data["response"]

        except requests.exceptions.RequestException as e:
            await self.status_emitter.emit(
                status="error",
                description=f"Error during LightRag query: {str(e)}",
                done=True,
            )
            return json.dumps({"error": str(e)})
    def extract_code_blocks(
        self, text: str, return_remaining_text: bool = False
    ) -> Union[List[dict], Tuple[List[dict], str]]:
        """
        This function extracts code blocks from a given text and optionally returns the text without code blocks.

        Parameters:
        text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
        return_remaining_text (bool): If True, also returns the text with code blocks removed.

        Returns:
        Union[List[dict], Tuple[List[dict], str]]:
        - If return_remaining_text is False: Returns only the list of code block dictionaries
        - If return_remaining_text is True: Returns a tuple containing:
          * List of code block dictionaries
          * String containing the text with all code blocks removed

        Each code block dictionary contains:
        - 'index' (int): The index of the code block in the text
        - 'file_name' (str): The name of the file extracted from the preceding line, if available
        - 'section' (str): The section name extracted from the preceding line, if available
        - 'content' (str): The content of the code block
        - 'type' (str): The type of the code block
        - 'is_complete' (bool): True if the block has a closing tag, False otherwise
        """
        remaining = text
        bloc_index = 0
        first_index = 0
        indices = []
        text_without_blocks = text

        # Find all code block delimiters
        while len(remaining) > 0:
            try:
                index = remaining.index("```")
                indices.append(index + first_index)
                remaining = remaining[index + 3 :]
                first_index += index + 3
                bloc_index += 1
            except Exception:
                if bloc_index % 2 == 1:
                    index = len(remaining)
                    indices.append(index)
                remaining = ""

        code_blocks = []
        is_start = True

        # Process code blocks and build text without blocks if requested
        if return_remaining_text:
            text_parts = []
            last_end = 0

        for index, code_delimiter_position in enumerate(indices):
            if is_start:
                block_infos = {
                    "index": len(code_blocks),
                    "file_name": "",
                    "section": "",
                    "content": "",
                    "type": "",
                    "is_complete": False,
                }
                # Store text before code block if returning remaining text
                if return_remaining_text:
                    text_parts.append(text[last_end:code_delimiter_position].strip())

                # Check the preceding line for file name
                preceding_text = text[:code_delimiter_position].strip().splitlines()
                if preceding_text:
                    last_line = preceding_text[-1].strip()
                    if last_line.startswith("<file_name>") and last_line.endswith(
                        "</file_name>"
                    ):
                        file_name = last_line[
                            len("<file_name>") : -len("</file_name>")
                        ].strip()
                        block_infos["file_name"] = file_name
                    elif last_line.startswith("## filename:"):
                        file_name = last_line[len("## filename:") :].strip()
                        block_infos["file_name"] = file_name
                    if last_line.startswith("<section>") and last_line.endswith(
                        "</section>"
                    ):
                        section = last_line[
                            len("<section>") : -len("</section>")
                        ].strip()
                        block_infos["section"] = section

                sub_text = text[code_delimiter_position + 3 :]
                if len(sub_text) > 0:
                    try:
                        find_space = sub_text.index(" ")
                    except Exception:
                        find_space = int(1e10)
                    try:
                        find_return = sub_text.index("\n")
                    except Exception:
                        find_return = int(1e10)
                    next_index = min(find_return, find_space)
                    if "{" in sub_text[:next_index]:
                        next_index = 0
                    start_pos = next_index
                    if code_delimiter_position + 3 < len(text) and text[
                        code_delimiter_position + 3
                    ] in ["\n", " ", "\t"]:
                        block_infos["type"] = "language-specific"
                    else:
                        block_infos["type"] = sub_text[:next_index]

                    if index + 1 < len(indices):
                        next_pos = indices[index + 1] - code_delimiter_position
                        if (
                            next_pos - 3 < len(sub_text)
                            and sub_text[next_pos - 3] == "`"
                        ):
                            block_infos["content"] = sub_text[
                                start_pos : next_pos - 3
                            ].strip()
                            block_infos["is_complete"] = True
                        else:
                            block_infos["content"] = sub_text[
                                start_pos:next_pos
                            ].strip()
                            block_infos["is_complete"] = False

                        if return_remaining_text:
                            last_end = indices[index + 1] + 3
                    else:
                        block_infos["content"] = sub_text[start_pos:].strip()
                        block_infos["is_complete"] = False

                        if return_remaining_text:
                            last_end = len(text)

                    code_blocks.append(block_infos)
                is_start = False
            else:
                is_start = True

        if return_remaining_text:
            # Add any remaining text after the last code block
            if last_end < len(text):
                text_parts.append(text[last_end:].strip())
            # Join all non-code parts with newlines
            text_without_blocks = "\n".join(filter(None, text_parts))
            return code_blocks, text_without_blocks

        return code_blocks
    def clean(self, csv_content: str):
        lines = csv_content.splitlines()
        if lines:
            # Remove spaces around headers and ensure no spaces between commas
            header = ",".join([col.strip() for col in lines[0].split(",")])
            lines[0] = header  # Replace the first line with the cleaned header
        csv_content = "\n".join(lines)
        return csv_content
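
Since this file is deleted by the commit, a minimal record of how its two pure helpers behave may be useful; the snippet assumes the file above is saved as `lightrag_owui_tool.py` (a hypothetical module name) with `pydantic` and `requests` installed:

```python
from lightrag_owui_tool import Tools  # hypothetical module name for the file above

tool = Tools()

# extract_code_blocks: finds ```-fenced blocks and the surrounding text.
sample = "intro\n```python\nprint('hi')\n```\noutro"
blocks, remaining = tool.extract_code_blocks(sample, return_remaining_text=True)
print(blocks[0]["type"])         # python
print(blocks[0]["content"])      # print('hi')
print(blocks[0]["is_complete"])  # True
print(remaining)                 # intro\noutro

# clean: normalizes spacing in a CSV header row only.
print(tool.clean(" id , name ,score\n1,Alice,3"))
# id,name,score
# 1,Alice,3
```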


@@ -83,11 +83,11 @@ class StorageNameSpace:
     namespace: str
     global_config: dict[str, Any]

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
         pass

-    async def query_done_callback(self):
+    async def query_done_callback(self) -> None:
         """Commit the storage operations after querying"""
         pass
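
The added `-> None` annotations make the callback contract explicit for implementers. A minimal sketch of a concrete namespace honoring it (`InMemoryStorage` is illustrative, not a class from this codebase):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class InMemoryStorage:
    # Mirrors the StorageNameSpace fields shown above.
    namespace: str
    global_config: dict[str, Any]

    async def index_done_callback(self) -> None:
        # Commit/flush after indexing; a no-op for a purely in-memory store.
        pass

    async def query_done_callback(self) -> None:
        pass
```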


@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, Callable, Optional, Union, cast
 from .base import (
     BaseGraphStorage,
@@ -304,7 +304,7 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """
-    embedding_func: Union[EmbeddingFunc, None] = None
+    embedding_func: EmbeddingFunc | None = None
     """Function for computing text embeddings. Must be set before use."""
     embedding_batch_num: int = 32
@@ -344,10 +344,8 @@ class LightRAG:
# Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)
     """Dictionary for additional parameters and extensions."""
-    # extension
-    addon_params: dict[str, Any] = field(default_factory=dict)
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
     )
@@ -445,77 +443,74 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         # Init LLM
         self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
             self.embedding_func
         )

         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (  # type: ignore
+        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
             self._get_storage_class(self.kv_storage)
-        )
+        )  # type: ignore
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(  # type: ignore
+        self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
             self.vector_storage
-        )
+        )  # type: ignore
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(  # type: ignore
+        self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
             self.graph_storage
-        )
+        )  # type: ignore
         self.key_string_value_json_storage_cls = partial(  # type: ignore
             self.key_string_value_json_storage_cls, global_config=global_config
         )
         self.vector_db_storage_cls = partial(  # type: ignore
             self.vector_db_storage_cls, global_config=global_config
         )
         self.graph_storage_cls = partial(  # type: ignore
             self.graph_storage_cls, global_config=global_config
         )

         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

-        self.llm_response_cache = self.key_string_value_json_storage_cls(  # type: ignore
+        self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
             embedding_func=self.embedding_func,
         )
         self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
         self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
         self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
             embedding_func=self.embedding_func,
         )
-        self.entities_vdb = self.vector_db_storage_cls(  # type: ignore
+        self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
-        self.relationships_vdb = self.vector_db_storage_cls(  # type: ignore
+        self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
         self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
@@ -535,16 +530,16 @@ class LightRAG:
         ):
             hashing_kv = self.llm_response_cache
         else:
             hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                 namespace=make_namespace(
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
                 embedding_func=self.embedding_func,
             )
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
                 self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
                 **self.llm_model_kwargs,
             )
@@ -836,32 +831,32 @@ class LightRAG:
             raise e

     async def _insert_done(self):
-        tasks = []
-        for storage_inst in [
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.chunks_vdb,
-            self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
+        tasks = [
+            cast(StorageNameSpace, storage_inst).index_done_callback()
+            for storage_inst in [  # type: ignore
+                self.full_docs,
+                self.text_chunks,
+                self.llm_response_cache,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+            ]
+            if storage_inst is not None
+        ]
         await asyncio.gather(*tasks)

-    def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
+    def insert_custom_kg(self, custom_kg: dict[str, Any]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

-    async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
         update_storage = False
         try:
             # Insert chunks into vector storage
-            all_chunks_data = {}
-            chunk_to_source_map = {}
-            for chunk_data in custom_kg.get("chunks", []):
+            all_chunks_data: dict[str, dict[str, str]] = {}
+            chunk_to_source_map: dict[str, str] = {}
+            for chunk_data in custom_kg.get("chunks", {}):
                 chunk_content = chunk_data["content"]
                 source_id = chunk_data["source_id"]
                 chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
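
The `_insert_done` change above swaps an accumulate-and-append loop for a filtered list comprehension handed to `asyncio.gather`. The same pattern in isolation, with standalone stand-in names rather than LightRAG's storages:

```python
import asyncio

async def flush(name: str) -> str:
    return f"{name} flushed"

async def main() -> None:
    storages = ["full_docs", None, "chunks_vdb"]  # None mimics an unconfigured storage
    # Build coroutines only for configured storages, then await them together.
    tasks = [flush(s) for s in storages if s is not None]
    print(await asyncio.gather(*tasks))  # ['full_docs flushed', 'chunks_vdb flushed']

asyncio.run(main())
```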
@@ -871,13 +866,13 @@ class LightRAG:
                 chunk_to_source_map[source_id] = chunk_id
                 update_storage = True

-            if self.chunks_vdb is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.chunks_vdb.upsert(all_chunks_data)
-            if self.text_chunks is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.text_chunks.upsert(all_chunks_data)

             # Insert entities into knowledge graph
-            all_entities_data = []
+            all_entities_data: list[dict[str, str]] = []
             for entity_data in custom_kg.get("entities", []):
                 entity_name = f'"{entity_data["entity_name"].upper()}"'
                 entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -893,7 +888,7 @@ class LightRAG:
                 )

                 # Prepare node data
-                node_data = {
+                node_data: dict[str, str] = {
                     "entity_type": entity_type,
                     "description": description,
                     "source_id": source_id,
@@ -907,7 +902,7 @@ class LightRAG:
                 update_storage = True

             # Insert relationships into knowledge graph
-            all_relationships_data = []
+            all_relationships_data: list[dict[str, str]] = []
             for relationship_data in custom_kg.get("relationships", []):
                 src_id = f'"{relationship_data["src_id"].upper()}"'
                 tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -949,7 +944,7 @@ class LightRAG:
"source_id": source_id, "source_id": source_id,
}, },
) )
edge_data = { edge_data: dict[str, str] = {
"src_id": src_id, "src_id": src_id,
"tgt_id": tgt_id, "tgt_id": tgt_id,
"description": description, "description": description,
@@ -959,19 +954,17 @@ class LightRAG:
                 update_storage = True

             # Insert entities into vector storage if needed
-            if self.entities_vdb is not None:
-                data_for_vdb = {
+            data_for_vdb = {
                 compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                     "content": dp["entity_name"] + dp["description"],
                     "entity_name": dp["entity_name"],
                 }
                 for dp in all_entities_data
             }
             await self.entities_vdb.upsert(data_for_vdb)

             # Insert relationships into vector storage if needed
-            if self.relationships_vdb is not None:
-                data_for_vdb = {
+            data_for_vdb = {
@@ -982,18 +975,49 @@ class LightRAG:
                 }
                 for dp in all_relationships_data
             }
             await self.relationships_vdb.upsert(data_for_vdb)

         finally:
             if update_storage:
                 await self._insert_done()

-    def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
+    def query(
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        prompt: str | None = None,
+    ) -> str:
+        """
+        Perform a sync query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.aquery(query, prompt, param))
+        return loop.run_until_complete(self.aquery(query, param, prompt))

     async def aquery(
-        self, query: str, prompt: str = "", param: QueryParam = QueryParam()
-    ):
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        prompt: str | None = None,
+    ) -> str:
+        """
+        Perform an async query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query,
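
With the new signatures, `param` comes before the optional `prompt` in positional order. A minimal sketch of a call under the reordered API; the working directory and model functions are assumed to be configured as in the README:

```python
from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./rag_storage")  # embedding/LLM funcs assumed configured
answer = rag.query(
    "What are the top themes?",
    param=QueryParam(mode="hybrid"),
    prompt=None,  # None falls back to PROMPTS["rag_response"]
)
print(answer)
```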


@@ -295,8 +295,8 @@ async def extract_entities(
     knowledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    global_config: dict,
-    llm_response_cache: BaseKVStorage = None,
+    global_config: dict[str, str],
+    llm_response_cache: BaseKVStorage | None = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -563,15 +563,15 @@ async def extract_entities(
 async def kg_query(
-    query,
+    query: str,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-    prompt: str = "",
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+    prompt: str | None = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -681,8 +681,8 @@ async def kg_query(
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> tuple[list[str], list[str]]:
     """
     Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -778,8 +778,8 @@ async def mix_kg_vector_query(
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> str:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1499,12 +1499,12 @@ def combine_contexts(entities, relationships, sources):
 async def naive_query(
-    query,
+    query: str,
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ):
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -1606,8 +1606,8 @@ async def kg_query_with_keywords(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> str:
     """
     Refactored kg_query that does NOT extract keywords by itself.
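
These functions are internal, but the optional-cache convention above can be exercised directly. A sketch, under the assumptions that `operate.py` is importable as `lightrag.operate` and that `global_config` is the `asdict` of a configured LightRAG instance (both inferred, not confirmed by this diff):

```python
import asyncio
from dataclasses import asdict
from lightrag import LightRAG, QueryParam
from lightrag.operate import extract_keywords_only  # module path assumed

async def main() -> None:
    rag = LightRAG(working_dir="./rag_storage")  # LLM/embedding funcs assumed configured
    hl, ll = await extract_keywords_only(
        "How do tariffs affect chip supply chains?",
        param=QueryParam(mode="hybrid"),
        global_config=asdict(rag),  # assumed to carry llm_model_func, as read above
        hashing_kv=None,  # None disables response caching, per the new signature
    )
    print("high-level:", hl, "low-level:", ll)

asyncio.run(main())
```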


@@ -128,7 +128,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     return hashlib.md5(args_str.encode()).hexdigest()

-def compute_mdhash_id(content, prefix: str = ""):
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+    """
+    Compute a unique ID for a given content string.
+    The ID is a combination of the given prefix and the MD5 hash of the content string.
+    """
     return prefix + md5(content.encode()).hexdigest()
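
As a quick check of the documented behavior, the function is small enough to reproduce verbatim; `5d41402abc4b2a76b9719d911017c592` is the well-known MD5 of `"hello"`:

```python
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Identical to the function above: prefix + MD5 hex digest of the content.
    return prefix + md5(content.encode()).hexdigest()

print(compute_mdhash_id("hello", prefix="chunk-"))
# chunk-5d41402abc4b2a76b9719d911017c592
```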