Clean up messages and remove a project file that is no longer needed
@@ -428,9 +428,9 @@ And using a routine to process news documents.

 ```python
 rag = LightRAG(..)
-await rag.apipeline_enqueue_documents(string_or_strings)
+await rag.apipeline_enqueue_documents(input)
 # Your routine in loop
-await rag.apipeline_process_enqueue_documents(string_or_strings)
+await rag.apipeline_process_enqueue_documents(input)
 ```

 ### Separate Keyword Extraction
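For readers updating call sites, here is a minimal sketch of the renamed pipeline API from the hunk above. The working directory and document strings are made up, and a real `LightRAG` instance would also need an LLM and embedding function configured; the diff itself is the authority on the signatures:

```python
import asyncio
from lightrag import LightRAG

async def main():
    rag = LightRAG(working_dir="./rag_storage")  # hypothetical minimal setup
    docs = ["First news document...", "Second news document..."]
    # The parameter was renamed from `string_or_strings` to `input`;
    # it still accepts a single string or a list of strings.
    await rag.apipeline_enqueue_documents(docs)
    # Your routine in loop
    await rag.apipeline_process_enqueue_documents(docs)

asyncio.run(main())
```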
@@ -113,7 +113,24 @@ async def main():
     )

     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
-    rag.set_storage_client(db_client=oracle_db)
+    for storage in [
+        rag.vector_db_storage_cls,
+        rag.graph_storage_cls,
+        rag.doc_status,
+        rag.full_docs,
+        rag.text_chunks,
+        rag.llm_response_cache,
+        rag.key_string_value_json_storage_cls,
+        rag.chunks_vdb,
+        rag.relationships_vdb,
+        rag.entities_vdb,
+        rag.graph_storage_cls,
+        rag.chunk_entity_relation_graph,
+        rag.llm_response_cache,
+    ]:
+        # set client
+        storage.db = oracle_db

     # Extract and Insert into LightRAG storage
    with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
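The point of the loop is that every storage backend ends up pointing at one shared client, so they all reuse a single connection pool instead of each opening its own. A self-contained sketch of the pattern, where `FakePooledClient` and `FakeStorage` are stand-ins rather than LightRAG classes:

```python
class FakePooledClient:
    """Stands in for the Oracle client created earlier in the example."""
    def __init__(self, dsn: str):
        self.dsn = dsn  # a real client would create its connection pool here

class FakeStorage:
    db = None  # each storage exposes a `db` attribute to be set

oracle_db = FakePooledClient(dsn="localhost/orclpdb1")  # made-up DSN
storages = [FakeStorage(), FakeStorage(), FakeStorage()]
for storage in storages:
    storage.db = oracle_db  # set client

# All storages now share the same pooled client.
assert all(s.db is oracle_db for s in storages)
```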
@@ -1,358 +0,0 @@
-"""
-OpenWebui Lightrag Integration Tool
-==================================
-
-This tool enables the integration and use of Lightrag within the OpenWebui environment,
-providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
-
-Author: ParisNeo (parisneoai@gmail.com)
-Social:
-- Twitter: @ParisNeo_AI
-- Reddit: r/lollms
-- Instagram: https://www.instagram.com/parisneo_ai/
-
-License: Apache 2.0
-Copyright (c) 2024-2025 ParisNeo
-
-This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
-For more information, visit: https://github.com/ParisNeo/lollms
-
-Requirements:
-- Python 3.8+
-- OpenWebui
-- Lightrag
-"""
-
-# Tool version
-__version__ = "1.0.0"
-__author__ = "ParisNeo"
-__author_email__ = "parisneoai@gmail.com"
-__description__ = "Lightrag integration for OpenWebui"
-
-
-import requests
-import json
-from pydantic import BaseModel, Field
-from typing import Callable, Any, Literal, Union, List, Tuple
-
-
-class StatusEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, description="Unknown State", status="in_progress", done=False):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "status",
-                    "data": {
-                        "status": status,
-                        "description": description,
-                        "done": done,
-                    },
-                }
-            )
-
-
-class MessageEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, content="Some message"):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "message",
-                    "data": {
-                        "content": content,
-                    },
-                }
-            )
-
-
-class Tools:
-    class Valves(BaseModel):
-        LIGHTRAG_SERVER_URL: str = Field(
-            default="http://localhost:9621/query",
-            description="The base URL for the LightRag server",
-        )
-        MODE: Literal["naive", "local", "global", "hybrid"] = Field(
-            default="hybrid",
-            description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
-        )
-        ONLY_NEED_CONTEXT: bool = Field(
-            default=False,
-            description="If True, only the context is needed from the LightRag response",
-        )
-        DEBUG_MODE: bool = Field(
-            default=False,
-            description="If True, debugging information will be emitted",
-        )
-        KEY: str = Field(
-            default="",
-            description="Optional Bearer Key for authentication",
-        )
-        MAX_ENTITIES: int = Field(
-            default=5,
-            description="Maximum number of entities to keep",
-        )
-        MAX_RELATIONSHIPS: int = Field(
-            default=5,
-            description="Maximum number of relationships to keep",
-        )
-        MAX_SOURCES: int = Field(
-            default=3,
-            description="Maximum number of sources to keep",
-        )
-
-    def __init__(self):
-        self.valves = self.Valves()
-        self.headers = {
-            "Content-Type": "application/json",
-            "User-Agent": "LightRag-Tool/1.0",
-        }
-
-    async def query_lightrag(
-        self,
-        query: str,
-        __event_emitter__: Callable[[dict], Any] = None,
-    ) -> str:
-        """
-        Query the LightRag server and retrieve information.
-        This function must be called before answering the user question
-        :params query: The query string to send to the LightRag server.
-        :return: The response from the LightRag server in Markdown format or raw response.
-        """
-        self.status_emitter = StatusEventEmitter(__event_emitter__)
-        self.message_emitter = MessageEventEmitter(__event_emitter__)
-
-        lightrag_url = self.valves.LIGHTRAG_SERVER_URL
-        payload = {
-            "query": query,
-            "mode": str(self.valves.MODE),
-            "stream": False,
-            "only_need_context": self.valves.ONLY_NEED_CONTEXT,
-        }
-        await self.status_emitter.emit("Initializing Lightrag query..")
-
-        if self.valves.DEBUG_MODE:
-            await self.message_emitter.emit(
-                "### Debug Mode Active\n\nDebugging information will be displayed.\n"
-            )
-            await self.message_emitter.emit(
-                "#### Payload Sent to LightRag Server\n```json\n"
-                + json.dumps(payload, indent=4)
-                + "\n```\n"
-            )
-
-        # Add Bearer Key to headers if provided
-        if self.valves.KEY:
-            self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
-
-        try:
-            await self.status_emitter.emit("Sending request to LightRag server")
-
-            response = requests.post(
-                lightrag_url, json=payload, headers=self.headers, timeout=120
-            )
-            response.raise_for_status()
-            data = response.json()
-            await self.status_emitter.emit(
-                status="complete",
-                description="LightRag query Succeeded",
-                done=True,
-            )
-
-            # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
-            if self.valves.ONLY_NEED_CONTEXT:
-                try:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### LightRag Server Response\n```json\n"
-                            + data["response"]
-                            + "\n```\n"
-                        )
-                except Exception as ex:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### Exception\n" + str(ex) + "\n"
-                        )
-                    return f"Exception: {ex}"
-                return data["response"]
-            else:
-                if self.valves.DEBUG_MODE:
-                    await self.message_emitter.emit(
-                        "#### LightRag Server Response\n```json\n"
-                        + data["response"]
-                        + "\n```\n"
-                    )
-                await self.status_emitter.emit("Lightrag query success")
-                return data["response"]
-
-        except requests.exceptions.RequestException as e:
-            await self.status_emitter.emit(
-                status="error",
-                description=f"Error during LightRag query: {str(e)}",
-                done=True,
-            )
-            return json.dumps({"error": str(e)})
-
-    def extract_code_blocks(
-        self, text: str, return_remaining_text: bool = False
-    ) -> Union[List[dict], Tuple[List[dict], str]]:
-        """
-        This function extracts code blocks from a given text and optionally returns the text without code blocks.
-
-        Parameters:
-        text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
-        return_remaining_text (bool): If True, also returns the text with code blocks removed.
-
-        Returns:
-        Union[List[dict], Tuple[List[dict], str]]:
-        - If return_remaining_text is False: Returns only the list of code block dictionaries
-        - If return_remaining_text is True: Returns a tuple containing:
-          * List of code block dictionaries
-          * String containing the text with all code blocks removed
-
-        Each code block dictionary contains:
-        - 'index' (int): The index of the code block in the text
-        - 'file_name' (str): The name of the file extracted from the preceding line, if available
-        - 'content' (str): The content of the code block
-        - 'type' (str): The type of the code block
-        - 'is_complete' (bool): True if the block has a closing tag, False otherwise
-        """
-        remaining = text
-        bloc_index = 0
-        first_index = 0
-        indices = []
-        text_without_blocks = text
-
-        # Find all code block delimiters
-        while len(remaining) > 0:
-            try:
-                index = remaining.index("```")
-                indices.append(index + first_index)
-                remaining = remaining[index + 3 :]
-                first_index += index + 3
-                bloc_index += 1
-            except Exception:
-                if bloc_index % 2 == 1:
-                    index = len(remaining)
-                    indices.append(index)
-                remaining = ""
-
-        code_blocks = []
-        is_start = True
-
-        # Process code blocks and build text without blocks if requested
-        if return_remaining_text:
-            text_parts = []
-            last_end = 0
-
-        for index, code_delimiter_position in enumerate(indices):
-            if is_start:
-                block_infos = {
-                    "index": len(code_blocks),
-                    "file_name": "",
-                    "section": "",
-                    "content": "",
-                    "type": "",
-                    "is_complete": False,
-                }
-
-                # Store text before code block if returning remaining text
-                if return_remaining_text:
-                    text_parts.append(text[last_end:code_delimiter_position].strip())
-
-                # Check the preceding line for file name
-                preceding_text = text[:code_delimiter_position].strip().splitlines()
-                if preceding_text:
-                    last_line = preceding_text[-1].strip()
-                    if last_line.startswith("<file_name>") and last_line.endswith(
-                        "</file_name>"
-                    ):
-                        file_name = last_line[
-                            len("<file_name>") : -len("</file_name>")
-                        ].strip()
-                        block_infos["file_name"] = file_name
-                    elif last_line.startswith("## filename:"):
-                        file_name = last_line[len("## filename:") :].strip()
-                        block_infos["file_name"] = file_name
-                    if last_line.startswith("<section>") and last_line.endswith(
-                        "</section>"
-                    ):
-                        section = last_line[
-                            len("<section>") : -len("</section>")
-                        ].strip()
-                        block_infos["section"] = section
-
-                sub_text = text[code_delimiter_position + 3 :]
-                if len(sub_text) > 0:
-                    try:
-                        find_space = sub_text.index(" ")
-                    except Exception:
-                        find_space = int(1e10)
-                    try:
-                        find_return = sub_text.index("\n")
-                    except Exception:
-                        find_return = int(1e10)
-                    next_index = min(find_return, find_space)
-                    if "{" in sub_text[:next_index]:
-                        next_index = 0
-                    start_pos = next_index
-
-                    if code_delimiter_position + 3 < len(text) and text[
-                        code_delimiter_position + 3
-                    ] in ["\n", " ", "\t"]:
-                        block_infos["type"] = "language-specific"
-                    else:
-                        block_infos["type"] = sub_text[:next_index]
-
-                    if index + 1 < len(indices):
-                        next_pos = indices[index + 1] - code_delimiter_position
-                        if (
-                            next_pos - 3 < len(sub_text)
-                            and sub_text[next_pos - 3] == "`"
-                        ):
-                            block_infos["content"] = sub_text[
-                                start_pos : next_pos - 3
-                            ].strip()
-                            block_infos["is_complete"] = True
-                        else:
-                            block_infos["content"] = sub_text[
-                                start_pos:next_pos
-                            ].strip()
-                            block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = indices[index + 1] + 3
-                    else:
-                        block_infos["content"] = sub_text[start_pos:].strip()
-                        block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = len(text)
-
-                code_blocks.append(block_infos)
-                is_start = False
-            else:
-                is_start = True
-
-        if return_remaining_text:
-            # Add any remaining text after the last code block
-            if last_end < len(text):
-                text_parts.append(text[last_end:].strip())
-            # Join all non-code parts with newlines
-            text_without_blocks = "\n".join(filter(None, text_parts))
-            return code_blocks, text_without_blocks

-        return code_blocks
-
-    def clean(self, csv_content: str):
-        lines = csv_content.splitlines()
-        if lines:
-            # Remove spaces around headers and ensure no spaces between commas
-            header = ",".join([col.strip() for col in lines[0].split(",")])
-            lines[0] = header  # Replace the first line with the cleaned header
-            csv_content = "\n".join(lines)
-        return csv_content
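For reference after this deletion, the removed tool boiled down to one HTTP call against the LightRAG server. A minimal sketch of that request, with the endpoint, payload keys, and headers taken from the deleted file (the query string is an example):

```python
import requests

payload = {
    "query": "What is LightRAG?",
    "mode": "hybrid",  # one of: naive, local, global, hybrid
    "stream": False,
    "only_need_context": False,
}
headers = {"Content-Type": "application/json", "User-Agent": "LightRag-Tool/1.0"}
resp = requests.post(
    "http://localhost:9621/query", json=payload, headers=headers, timeout=120
)
resp.raise_for_status()
print(resp.json()["response"])
```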
@@ -83,11 +83,11 @@ class StorageNameSpace:
     namespace: str
     global_config: dict[str, Any]

-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
         pass

-    async def query_done_callback(self):
+    async def query_done_callback(self) -> None:
         """Commit the storage operations after querying"""
         pass
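A sketch of how a concrete storage would adopt the new `-> None` annotations; `MyKVStorage` and its commit step are hypothetical, while the base-class shape follows the hunk above:

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class StorageNameSpace:
    namespace: str
    global_config: dict[str, Any] = field(default_factory=dict)

    async def index_done_callback(self) -> None:
        """Commit the storage operations after indexing"""

@dataclass
class MyKVStorage(StorageNameSpace):
    async def index_done_callback(self) -> None:
        # a real implementation would flush buffered writes here
        print(f"{self.namespace}: committed")
```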
@@ -6,7 +6,7 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, Callable, Optional, Union, cast

 from .base import (
     BaseGraphStorage,
@@ -304,7 +304,7 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """

-    embedding_func: Union[EmbeddingFunc, None] = None
+    embedding_func: EmbeddingFunc | None = None
     """Function for computing text embeddings. Must be set before use."""

     embedding_batch_num: int = 32
@@ -344,10 +344,8 @@ class LightRAG:

     # Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)
     """Dictionary for additional parameters and extensions."""

-    # extension
-    addon_params: dict[str, Any] = field(default_factory=dict)
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
     )
@@ -445,77 +443,74 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         # Init LLM
         self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
             self.embedding_func
         )

         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (  # type: ignore
-            self._get_storage_class(self.kv_storage)
-        )
+        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
+            self._get_storage_class(self.kv_storage)
+        )  # type: ignore
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(  # type: ignore
-            self.vector_storage
-        )
+        self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
+            self.vector_storage
+        )  # type: ignore
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(  # type: ignore
-            self.graph_storage
-        )
+        self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
+            self.graph_storage
+        )  # type: ignore
-
         self.key_string_value_json_storage_cls = partial(  # type: ignore
             self.key_string_value_json_storage_cls, global_config=global_config
         )
-
         self.vector_db_storage_cls = partial(  # type: ignore
             self.vector_db_storage_cls, global_config=global_config
         )
-
         self.graph_storage_cls = partial(  # type: ignore
             self.graph_storage_cls, global_config=global_config
         )

         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

-        self.llm_response_cache = self.key_string_value_json_storage_cls(  # type: ignore
+        self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
             embedding_func=self.embedding_func,
         )

         self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
             ),
             embedding_func=self.embedding_func,
         )
         self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
             ),
             embedding_func=self.embedding_func,
         )
         self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
             ),
             embedding_func=self.embedding_func,
         )

-        self.entities_vdb = self.vector_db_storage_cls(  # type: ignore
+        self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
             ),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
-        self.relationships_vdb = self.vector_db_storage_cls(  # type: ignore
+        self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
             ),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
         self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
             ),
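Why the `partial(...)` assignments above work: binding `global_config` up front turns each storage class into a factory that is still called exactly like the class itself, with only the remaining arguments. A self-contained sketch of the pattern (the class here is illustrative, not the LightRAG implementation):

```python
from functools import partial

class KVStorage:
    def __init__(self, namespace: str, global_config: dict, embedding_func=None):
        self.namespace = namespace
        self.global_config = global_config

# Pre-bind the shared config once...
storage_cls = partial(KVStorage, global_config={"working_dir": "./rag"})

# ...then construct instances with just the per-instance arguments.
cache = storage_cls(namespace="llm_response_cache")
print(cache.namespace, cache.global_config)  # global_config was pre-bound
```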
@@ -535,16 +530,16 @@ class LightRAG:
         ):
             hashing_kv = self.llm_response_cache
         else:
             hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                 namespace=make_namespace(
                     self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                 ),
                 embedding_func=self.embedding_func,
             )

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
                 self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
                 **self.llm_model_kwargs,
             )
@@ -836,32 +831,32 @@ class LightRAG:
             raise e

     async def _insert_done(self):
-        tasks = []
-        for storage_inst in [
-            self.full_docs,
-            self.text_chunks,
-            self.llm_response_cache,
-            self.entities_vdb,
-            self.relationships_vdb,
-            self.chunks_vdb,
-            self.chunk_entity_relation_graph,
-        ]:
-            if storage_inst is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
+        tasks = [
+            cast(StorageNameSpace, storage_inst).index_done_callback()
+            for storage_inst in [  # type: ignore
+                self.full_docs,
+                self.text_chunks,
+                self.llm_response_cache,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+            ]
+            if storage_inst is not None
+        ]
         await asyncio.gather(*tasks)

-    def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
+    def insert_custom_kg(self, custom_kg: dict[str, Any]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

-    async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
+    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
         update_storage = False
         try:
             # Insert chunks into vector storage
-            all_chunks_data = {}
-            chunk_to_source_map = {}
-            for chunk_data in custom_kg.get("chunks", []):
+            all_chunks_data: dict[str, dict[str, str]] = {}
+            chunk_to_source_map: dict[str, str] = {}
+            for chunk_data in custom_kg.get("chunks", {}):
                 chunk_content = chunk_data["content"]
                 source_id = chunk_data["source_id"]
                 chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
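The `_insert_done` refactor collapses the append loop into a single comprehension that filters out unset storages before awaiting all callbacks together. A runnable reduction of the same pattern, with plain strings standing in for storage objects:

```python
import asyncio

async def index_done_callback(name: str) -> str:
    return f"{name} committed"

async def main():
    storages = ["full_docs", None, "text_chunks"]  # None mimics an unset storage
    tasks = [index_done_callback(s) for s in storages if s is not None]
    # gather awaits all remaining callbacks concurrently
    print(await asyncio.gather(*tasks))

asyncio.run(main())
```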
@@ -871,13 +866,13 @@ class LightRAG:
                 chunk_to_source_map[source_id] = chunk_id
                 update_storage = True

-            if self.chunks_vdb is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.chunks_vdb.upsert(all_chunks_data)
-            if self.text_chunks is not None and all_chunks_data:
+            if all_chunks_data:
                 await self.text_chunks.upsert(all_chunks_data)

             # Insert entities into knowledge graph
-            all_entities_data = []
+            all_entities_data: list[dict[str, str]] = []
             for entity_data in custom_kg.get("entities", []):
                 entity_name = f'"{entity_data["entity_name"].upper()}"'
                 entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -893,7 +888,7 @@ class LightRAG:
                 )

                 # Prepare node data
-                node_data = {
+                node_data: dict[str, str] = {
                     "entity_type": entity_type,
                     "description": description,
                     "source_id": source_id,
@@ -907,7 +902,7 @@ class LightRAG:
                 update_storage = True

             # Insert relationships into knowledge graph
-            all_relationships_data = []
+            all_relationships_data: list[dict[str, str]] = []
             for relationship_data in custom_kg.get("relationships", []):
                 src_id = f'"{relationship_data["src_id"].upper()}"'
                 tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -949,7 +944,7 @@ class LightRAG:
                         "source_id": source_id,
                     },
                 )
-                edge_data = {
+                edge_data: dict[str, str] = {
                     "src_id": src_id,
                     "tgt_id": tgt_id,
                     "description": description,
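Pulling together the keys this code reads ("chunks", "entities", "relationships" and their fields), an illustrative input for `ainsert_custom_kg`; the field values are made up:

```python
custom_kg = {
    "chunks": [
        {"content": "Alice founded Acme in 2001.", "source_id": "doc-1"},
    ],
    "entities": [
        {"entity_name": "Alice", "entity_type": "PERSON",
         "description": "Founder of Acme", "source_id": "doc-1"},
    ],
    "relationships": [
        {"src_id": "Alice", "tgt_id": "Acme",
         "description": "founded", "source_id": "doc-1"},
    ],
}
# await rag.ainsert_custom_kg(custom_kg)   # async path
# rag.insert_custom_kg(custom_kg)          # sync wrapper
```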
@@ -959,19 +954,17 @@ class LightRAG:
                 update_storage = True

             # Insert entities into vector storage if needed
-            if self.entities_vdb is not None:
-                data_for_vdb = {
+            data_for_vdb = {
                 compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                     "content": dp["entity_name"] + dp["description"],
                     "entity_name": dp["entity_name"],
                 }
                 for dp in all_entities_data
             }
             await self.entities_vdb.upsert(data_for_vdb)

             # Insert relationships into vector storage if needed
-            if self.relationships_vdb is not None:
-                data_for_vdb = {
+            data_for_vdb = {
                 compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                     "src_id": dp["src_id"],
                     "tgt_id": dp["tgt_id"],
@@ -982,18 +975,49 @@ class LightRAG:
                 }
                 for dp in all_relationships_data
             }
             await self.relationships_vdb.upsert(data_for_vdb)

         finally:
             if update_storage:
                 await self._insert_done()

-    def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
+    def query(
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        prompt: str | None = None
+    ) -> str:
+        """
+        Perform a sync query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         loop = always_get_an_event_loop()
-        return loop.run_until_complete(self.aquery(query, prompt, param))
+        return loop.run_until_complete(self.aquery(query, param, prompt))

     async def aquery(
-        self, query: str, prompt: str = "", param: QueryParam = QueryParam()
-    ):
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        prompt: str | None = None,
+    ) -> str:
+        """
+        Perform an async query.
+
+        Args:
+            query (str): The query to be executed.
+            param (QueryParam): Configuration parameters for query execution.
+            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+        Returns:
+            str: The result of the query execution.
+        """
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query,
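Since `param` and `prompt` swap positions in this change, existing positional call sites need updating; the keyword form is the safest. A short usage sketch, assuming `rag` is an already-configured LightRAG instance:

```python
from lightrag import QueryParam

answer = rag.query(
    "Summarize the indexed documents",
    param=QueryParam(mode="hybrid"),  # naive, local, global, or hybrid
    prompt=None,  # None falls back to PROMPTS["rag_response"]
)
print(answer)
```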
@@ -295,8 +295,8 @@ async def extract_entities(
     knowledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
-    global_config: dict,
-    llm_response_cache: BaseKVStorage = None,
+    global_config: dict[str, str],
+    llm_response_cache: BaseKVStorage | None = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -563,15 +563,15 @@ async def extract_entities(


 async def kg_query(
-    query,
+    query: str,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
-    prompt: str = "",
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+    prompt: str | None = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -681,8 +681,8 @@ async def kg_query(
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> tuple[list[str], list[str]]:
     """
     Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -778,8 +778,8 @@ async def mix_kg_vector_query(
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> str:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1499,12 +1499,12 @@ def combine_contexts(entities, relationships, sources):


 async def naive_query(
-    query,
+    query: str,
     chunks_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ):
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -1606,8 +1606,8 @@ async def kg_query_with_keywords(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    global_config: dict,
-    hashing_kv: BaseKVStorage = None,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
 ) -> str:
     """
     Refactored kg_query that does NOT extract keywords by itself.
@@ -128,7 +128,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     return hashlib.md5(args_str.encode()).hexdigest()


-def compute_mdhash_id(content, prefix: str = ""):
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+    """
+    Compute a unique ID for a given content string.
+
+    The ID is a combination of the given prefix and the MD5 hash of the content string.
+    """
     return prefix + md5(content.encode()).hexdigest()
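As the new docstring says, the ID is just the prefix concatenated with the MD5 hex digest of the content. A quick self-contained check of that behavior:

```python
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # prefix + 32-character MD5 hex digest of the content
    return prefix + md5(content.encode()).hexdigest()

print(compute_mdhash_id("hello", prefix="chunk-"))
# chunk-5d41402abc4b2a76b9719d911017c592
```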