Merge pull request #782 from YanSte/code-cleaning
Code Cleanup & Maintenance Improvements
@@ -428,9 +428,9 @@ And using a routine to process news documents.

```python
rag = LightRAG(..)

await rag.apipeline_enqueue_documents(string_or_strings)
await rag.apipeline_enqueue_documents(input)

# Your routine in loop
await rag.apipeline_process_enqueue_documents(string_or_strings)
await rag.apipeline_process_enqueue_documents(input)
```
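The hunk above renames the pipeline parameter from `string_or_strings` to `input`. A minimal sketch of the renamed API inside a polling routine — `fetch_news` and the sleep interval are hypothetical, the constructor arguments are elided as in the snippet, and the call shape mirrors the README lines above:

```python
import asyncio

from lightrag import LightRAG

async def ingest_loop(rag: LightRAG, fetch_news) -> None:
    # fetch_news is an assumed async callable returning a str or list[str]
    while True:
        docs = await fetch_news()
        if docs:
            # enqueue accepts a single string or a list of strings
            await rag.apipeline_enqueue_documents(docs)
            # process whatever is currently enqueued (mirrors the snippet above)
            await rag.apipeline_process_enqueue_documents(docs)
        await asyncio.sleep(60)
```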

### Separate Keyword Extraction
@@ -113,7 +113,24 @@ async def main():
    )

    # Set the KV/vector/graph storages' `db` property so all operations use the same connection pool
    rag.set_storage_client(db_client=oracle_db)

    for storage in [
        rag.vector_db_storage_cls,
        rag.graph_storage_cls,
        rag.doc_status,
        rag.full_docs,
        rag.text_chunks,
        rag.llm_response_cache,
        rag.key_string_value_json_storage_cls,
        rag.chunks_vdb,
        rag.relationships_vdb,
        rag.entities_vdb,
        rag.graph_storage_cls,
        rag.chunk_entity_relation_graph,
        rag.llm_response_cache,
    ]:
        # set client
        storage.db = oracle_db

    # Extract and insert into LightRAG storage
    with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
@@ -17,7 +17,9 @@ if not os.path.exists(WORKING_DIR):
# ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
CHROMADB_LOCAL_PATH = os.environ.get(
    "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
)
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
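The configuration block reads everything from the environment, so a deployment can switch between the local PersistentClient and the remote HttpClient without editing the script. A small sketch — the path and hostname are placeholders:

```python
import os

# Values must be set before the example script reads its configuration
os.environ["CHROMADB_LOCAL_PATH"] = "/var/data/chroma"   # local PersistentClient
os.environ["CHROMADB_HOST"] = "chroma.example.internal"  # remote HttpClient
os.environ["CHROMADB_PORT"] = "8000"
```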

@@ -1,358 +0,0 @@
"""
OpenWebui Lightrag Integration Tool
==================================

This tool enables the integration and use of Lightrag within the OpenWebui environment,
providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.

Author: ParisNeo (parisneoai@gmail.com)
Social:
- Twitter: @ParisNeo_AI
- Reddit: r/lollms
- Instagram: https://www.instagram.com/parisneo_ai/

License: Apache 2.0
Copyright (c) 2024-2025 ParisNeo

This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
For more information, visit: https://github.com/ParisNeo/lollms

Requirements:
- Python 3.8+
- OpenWebui
- Lightrag
"""

# Tool version
__version__ = "1.0.0"
__author__ = "ParisNeo"
__author_email__ = "parisneoai@gmail.com"
__description__ = "Lightrag integration for OpenWebui"


import requests
import json
from pydantic import BaseModel, Field
from typing import Callable, Any, Literal, Union, List, Tuple


class StatusEventEmitter:
    def __init__(self, event_emitter: Callable[[dict], Any] = None):
        self.event_emitter = event_emitter

    async def emit(self, description="Unknown State", status="in_progress", done=False):
        if self.event_emitter:
            await self.event_emitter(
                {
                    "type": "status",
                    "data": {
                        "status": status,
                        "description": description,
                        "done": done,
                    },
                }
            )


class MessageEventEmitter:
    def __init__(self, event_emitter: Callable[[dict], Any] = None):
        self.event_emitter = event_emitter

    async def emit(self, content="Some message"):
        if self.event_emitter:
            await self.event_emitter(
                {
                    "type": "message",
                    "data": {
                        "content": content,
                    },
                }
            )


class Tools:
    class Valves(BaseModel):
        LIGHTRAG_SERVER_URL: str = Field(
            default="http://localhost:9621/query",
            description="The base URL for the LightRag server",
        )
        MODE: Literal["naive", "local", "global", "hybrid"] = Field(
            default="hybrid",
            description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
        )
        ONLY_NEED_CONTEXT: bool = Field(
            default=False,
            description="If True, only the context is needed from the LightRag response",
        )
        DEBUG_MODE: bool = Field(
            default=False,
            description="If True, debugging information will be emitted",
        )
        KEY: str = Field(
            default="",
            description="Optional Bearer Key for authentication",
        )
        MAX_ENTITIES: int = Field(
            default=5,
            description="Maximum number of entities to keep",
        )
        MAX_RELATIONSHIPS: int = Field(
            default=5,
            description="Maximum number of relationships to keep",
        )
        MAX_SOURCES: int = Field(
            default=3,
            description="Maximum number of sources to keep",
        )

    def __init__(self):
        self.valves = self.Valves()
        self.headers = {
            "Content-Type": "application/json",
            "User-Agent": "LightRag-Tool/1.0",
        }

    async def query_lightrag(
        self,
        query: str,
        __event_emitter__: Callable[[dict], Any] = None,
    ) -> str:
        """
        Query the LightRag server and retrieve information.
        This function must be called before answering the user question.
        :params query: The query string to send to the LightRag server.
        :return: The response from the LightRag server in Markdown format or raw response.
        """
        self.status_emitter = StatusEventEmitter(__event_emitter__)
        self.message_emitter = MessageEventEmitter(__event_emitter__)

        lightrag_url = self.valves.LIGHTRAG_SERVER_URL
        payload = {
            "query": query,
            "mode": str(self.valves.MODE),
            "stream": False,
            "only_need_context": self.valves.ONLY_NEED_CONTEXT,
        }
        await self.status_emitter.emit("Initializing Lightrag query..")

        if self.valves.DEBUG_MODE:
            await self.message_emitter.emit(
                "### Debug Mode Active\n\nDebugging information will be displayed.\n"
            )
            await self.message_emitter.emit(
                "#### Payload Sent to LightRag Server\n```json\n"
                + json.dumps(payload, indent=4)
                + "\n```\n"
            )

        # Add Bearer Key to headers if provided
        if self.valves.KEY:
            self.headers["Authorization"] = f"Bearer {self.valves.KEY}"

        try:
            await self.status_emitter.emit("Sending request to LightRag server")

            response = requests.post(
                lightrag_url, json=payload, headers=self.headers, timeout=120
            )
            response.raise_for_status()
            data = response.json()
            await self.status_emitter.emit(
                status="complete",
                description="LightRag query Succeeded",
                done=True,
            )

            # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
            if self.valves.ONLY_NEED_CONTEXT:
                try:
                    if self.valves.DEBUG_MODE:
                        await self.message_emitter.emit(
                            "#### LightRag Server Response\n```json\n"
                            + data["response"]
                            + "\n```\n"
                        )
                except Exception as ex:
                    if self.valves.DEBUG_MODE:
                        await self.message_emitter.emit(
                            "#### Exception\n" + str(ex) + "\n"
                        )
                    return f"Exception: {ex}"
                return data["response"]
            else:
                if self.valves.DEBUG_MODE:
                    await self.message_emitter.emit(
                        "#### LightRag Server Response\n```json\n"
                        + data["response"]
                        + "\n```\n"
                    )
                await self.status_emitter.emit("Lightrag query success")
                return data["response"]

        except requests.exceptions.RequestException as e:
            await self.status_emitter.emit(
                status="error",
                description=f"Error during LightRag query: {str(e)}",
                done=True,
            )
            return json.dumps({"error": str(e)})

    def extract_code_blocks(
        self, text: str, return_remaining_text: bool = False
    ) -> Union[List[dict], Tuple[List[dict], str]]:
        """
        This function extracts code blocks from a given text and optionally returns the text without code blocks.

        Parameters:
        text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
        return_remaining_text (bool): If True, also returns the text with code blocks removed.

        Returns:
        Union[List[dict], Tuple[List[dict], str]]:
        - If return_remaining_text is False: Returns only the list of code block dictionaries
        - If return_remaining_text is True: Returns a tuple containing:
          * List of code block dictionaries
          * String containing the text with all code blocks removed

        Each code block dictionary contains:
        - 'index' (int): The index of the code block in the text
        - 'file_name' (str): The name of the file extracted from the preceding line, if available
        - 'content' (str): The content of the code block
        - 'type' (str): The type of the code block
        - 'is_complete' (bool): True if the block has a closing tag, False otherwise
        """
        remaining = text
        bloc_index = 0
        first_index = 0
        indices = []
        text_without_blocks = text

        # Find all code block delimiters
        while len(remaining) > 0:
            try:
                index = remaining.index("```")
                indices.append(index + first_index)
                remaining = remaining[index + 3 :]
                first_index += index + 3
                bloc_index += 1
            except Exception:
                if bloc_index % 2 == 1:
                    index = len(remaining)
                    indices.append(index)
                remaining = ""

        code_blocks = []
        is_start = True

        # Process code blocks and build text without blocks if requested
        if return_remaining_text:
            text_parts = []
            last_end = 0

        for index, code_delimiter_position in enumerate(indices):
            if is_start:
                block_infos = {
                    "index": len(code_blocks),
                    "file_name": "",
                    "section": "",
                    "content": "",
                    "type": "",
                    "is_complete": False,
                }

                # Store text before code block if returning remaining text
                if return_remaining_text:
                    text_parts.append(text[last_end:code_delimiter_position].strip())

                # Check the preceding line for file name
                preceding_text = text[:code_delimiter_position].strip().splitlines()
                if preceding_text:
                    last_line = preceding_text[-1].strip()
                    if last_line.startswith("<file_name>") and last_line.endswith(
                        "</file_name>"
                    ):
                        file_name = last_line[
                            len("<file_name>") : -len("</file_name>")
                        ].strip()
                        block_infos["file_name"] = file_name
                    elif last_line.startswith("## filename:"):
                        file_name = last_line[len("## filename:") :].strip()
                        block_infos["file_name"] = file_name
                    if last_line.startswith("<section>") and last_line.endswith(
                        "</section>"
                    ):
                        section = last_line[
                            len("<section>") : -len("</section>")
                        ].strip()
                        block_infos["section"] = section

                sub_text = text[code_delimiter_position + 3 :]
                if len(sub_text) > 0:
                    try:
                        find_space = sub_text.index(" ")
                    except Exception:
                        find_space = int(1e10)
                    try:
                        find_return = sub_text.index("\n")
                    except Exception:
                        find_return = int(1e10)
                    next_index = min(find_return, find_space)
                    if "{" in sub_text[:next_index]:
                        next_index = 0
                    start_pos = next_index

                    if code_delimiter_position + 3 < len(text) and text[
                        code_delimiter_position + 3
                    ] in ["\n", " ", "\t"]:
                        block_infos["type"] = "language-specific"
                    else:
                        block_infos["type"] = sub_text[:next_index]

                    if index + 1 < len(indices):
                        next_pos = indices[index + 1] - code_delimiter_position
                        if (
                            next_pos - 3 < len(sub_text)
                            and sub_text[next_pos - 3] == "`"
                        ):
                            block_infos["content"] = sub_text[
                                start_pos : next_pos - 3
                            ].strip()
                            block_infos["is_complete"] = True
                        else:
                            block_infos["content"] = sub_text[
                                start_pos:next_pos
                            ].strip()
                            block_infos["is_complete"] = False

                        if return_remaining_text:
                            last_end = indices[index + 1] + 3
                    else:
                        block_infos["content"] = sub_text[start_pos:].strip()
                        block_infos["is_complete"] = False

                        if return_remaining_text:
                            last_end = len(text)

                code_blocks.append(block_infos)
                is_start = False
            else:
                is_start = True

        if return_remaining_text:
            # Add any remaining text after the last code block
            if last_end < len(text):
                text_parts.append(text[last_end:].strip())
            # Join all non-code parts with newlines
            text_without_blocks = "\n".join(filter(None, text_parts))
            return code_blocks, text_without_blocks

        return code_blocks

    def clean(self, csv_content: str):
        lines = csv_content.splitlines()
        if lines:
            # Remove spaces around headers and ensure no spaces between commas
            header = ",".join([col.strip() for col in lines[0].split(",")])
            lines[0] = header  # Replace the first line with the cleaned header
        csv_content = "\n".join(lines)
        return csv_content
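The deleted tool was a thin wrapper around LightRAG's `/query` endpoint. For reference, a minimal direct call using the same payload fields and defaults the removed code sent (URL, mode, and timeout taken from the deleted defaults; the query text is a placeholder):

```python
import requests

payload = {
    "query": "What are the main themes?",
    "mode": "hybrid",              # naive | local | global | hybrid
    "stream": False,
    "only_need_context": False,
}
response = requests.post("http://localhost:9621/query", json=payload, timeout=120)
response.raise_for_status()
print(response.json()["response"])
```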

@@ -1,13 +1,13 @@
from __future__ import annotations

import os
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

import numpy as np
@@ -69,7 +69,7 @@ class QueryParam:
    ll_keywords: list[str] = field(default_factory=list)
    """List of low-level keywords to refine retrieval focus."""

    conversation_history: list[dict[str, Any]] = field(default_factory=list)
    conversation_history: list[dict[str, str]] = field(default_factory=list)
    """Stores past conversation history to maintain context.
    Format: [{"role": "user/assistant", "content": "message"}].
    """
@@ -83,19 +83,15 @@ class StorageNameSpace:
    namespace: str
    global_config: dict[str, Any]

    async def index_done_callback(self):
    async def index_done_callback(self) -> None:
        """Commit the storage operations after indexing"""
        pass

    async def query_done_callback(self):
        """Commit the storage operations after querying"""
        pass


@dataclass
class BaseVectorStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)
    meta_fields: set[str] = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
        raise NotImplementedError
@@ -106,12 +102,20 @@ class BaseVectorStorage(StorageNameSpace):
        """
        raise NotImplementedError

    async def delete_entity(self, entity_name: str) -> None:
        """Delete a single entity by its name"""
        raise NotImplementedError

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete relations for a given entity by scanning metadata"""
        raise NotImplementedError


@dataclass
class BaseKVStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc
    embedding_func: EmbeddingFunc | None = None

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        raise NotImplementedError

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -130,50 +134,75 @@ class BaseKVStorage(StorageNameSpace):

@dataclass
class BaseGraphStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc = None
    embedding_func: EmbeddingFunc | None = None
    """Check if a node exists in the graph."""

    async def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    """Check if an edge exists in the graph."""

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    """Get the degree of a node."""

    async def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    """Get the degree of an edge."""

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
    """Get a node by its id."""

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        raise NotImplementedError

    """Get an edge by its source and target node ids."""

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
    ) -> dict[str, str] | None:
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
    """Get all edges connected to a node."""

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
    """Upsert a node into the graph."""

    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        raise NotImplementedError

    """Upsert an edge into the graph."""

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
    ) -> None:
        raise NotImplementedError

    async def delete_node(self, node_id: str):
    """Delete a node from the graph."""

    async def delete_node(self, node_id: str) -> None:
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
    """Embed nodes using an algorithm."""

    async def embed_nodes(
        self, algorithm: str
    ) -> tuple[np.ndarray[Any, Any], list[str]]:
        raise NotImplementedError("Node embedding is not used in lightrag.")

    """Get all labels in the graph."""

    async def get_all_labels(self) -> list[str]:
        raise NotImplementedError

    """Get a knowledge graph of a node."""

    async def get_knowledge_graph(
        self, node_label: str, max_depth: int = 5
    ) -> KnowledgeGraph:
@@ -205,9 +234,9 @@ class DocProcessingStatus:
    """ISO format timestamp when document was created"""
    updated_at: str
    """ISO format timestamp when document was last updated"""
    chunks_count: Optional[int] = None
    chunks_count: int | None = None
    """Number of chunks after splitting, used for processing"""
    error: Optional[str] = None
    error: str | None = None
    """Error message if failed"""
    metadata: dict[str, Any] = field(default_factory=dict)
    """Additional metadata"""

@@ -1,3 +1,5 @@
from __future__ import annotations

import httpx
from typing import Literal

@@ -67,7 +67,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):

            if "token_authn" in auth_provider:
                headers = {
                    config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
                    config.get(
                        "auth_header_name", "X-Chroma-Token"
                    ): auth_credentials
                }
            elif "basic_authn" in auth_provider:
                auth_credentials = config.get("auth_credentials", "admin:admin")
@@ -154,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
            embedding = await self.embedding_func([query])

            results = self._collection.query(
                query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
                query_embeddings=embedding.tolist()
                if not isinstance(embedding, list)
                else embedding,
                n_results=top_k * 2,  # Request more results to allow for filtering
                include=["metadatas", "distances", "documents"],
            )

@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
        logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
        await self.delete([entity_id])

    async def delete_entity_relation(self, entity_name: str):
    async def delete_entity_relation(self, entity_name: str) -> None:
        """
        Delete relations for a given entity by scanning metadata.
        """

@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str):
    async def delete_entity_relation(self, entity_name: str) -> None:
        try:
            relations = [
                dp

@@ -1,10 +1,12 @@
from __future__ import annotations

import asyncio
import os
import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, Callable, Optional, Type, Union, cast
from typing import Any, AsyncIterator, Callable, Iterator, cast

from .base import (
    BaseGraphStorage,
@@ -92,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
}

# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS = {
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
    # KV Storage Implementations
    "JsonKVStorage": [],
    "MongoKVStorage": [],
@@ -179,7 +181,7 @@ STORAGES = {
}


def lazy_external_import(module_name: str, class_name: str):
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    """Lazily import a class from an external module based on the package of the caller."""
    # Get the caller's module and package
    import inspect
@@ -188,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args, **kwargs):
    def import_class(*args: Any, **kwargs: Any):
        import importlib

        module = importlib.import_module(module_name, package=package)
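`lazy_external_import` returns a factory that defers the module import until the class is actually instantiated, so optional storage backends do not have to be installed up front. A usage sketch — the module path and class name here are hypothetical:

```python
# Nothing is imported at this point
Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")

# The real import happens here, inside import_class
storage = Neo4JStorage(namespace="chunk_entity_relation", global_config={})
```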

@@ -305,7 +307,7 @@ class LightRAG:
    - random_seed: Seed value for reproducibility.
    """

    embedding_func: EmbeddingFunc = None
    embedding_func: EmbeddingFunc | None = None
    """Function for computing text embeddings. Must be set before use."""

    embedding_batch_num: int = 32
@@ -315,7 +317,7 @@ class LightRAG:
    """Maximum number of concurrent embedding function calls."""

    # LLM Configuration
    llm_model_func: callable = None
    llm_model_func: Callable[..., object] | None = None
    """Function for interacting with the large language model (LLM). Must be set before use."""

    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -345,10 +347,8 @@ class LightRAG:

    # Extensions
    addon_params: dict[str, Any] = field(default_factory=dict)
    """Dictionary for additional parameters and extensions."""

    # extension
    addon_params: dict[str, Any] = field(default_factory=dict)
    """Dictionary for additional parameters and extensions."""
    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
        convert_response_to_json
    )
@@ -357,7 +357,7 @@ class LightRAG:
    chunking_func: Callable[
        [
            str,
            Optional[str],
            str | None,
            bool,
            int,
            int,
@@ -446,77 +446,74 @@ class LightRAG:
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # Init LLM
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
            self.embedding_func
        )

        # Initialize all storages
        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
        self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
            self._get_storage_class(self.kv_storage)
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
        )  # type: ignore
        self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
            self.vector_storage
        )
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
        )  # type: ignore
        self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
            self.graph_storage
        )

        self.key_string_value_json_storage_cls = partial(
        )  # type: ignore
        self.key_string_value_json_storage_cls = partial(  # type: ignore
            self.key_string_value_json_storage_cls, global_config=global_config
        )

        self.vector_db_storage_cls = partial(
        self.vector_db_storage_cls = partial(  # type: ignore
            self.vector_db_storage_cls, global_config=global_config
        )

        self.graph_storage_cls = partial(
        self.graph_storage_cls = partial(  # type: ignore
            self.graph_storage_cls, global_config=global_config
        )

        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)

        self.llm_response_cache = self.key_string_value_json_storage_cls(
        self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
            ),
            embedding_func=self.embedding_func,
        )

        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
        self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
            ),
            embedding_func=self.embedding_func,
        )
        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
            ),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
        self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
            ),
            embedding_func=self.embedding_func,
        )

        self.entities_vdb = self.vector_db_storage_cls(
        self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
            ),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
        self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
            ),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
        self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(  # type: ignore
            namespace=make_namespace(
                self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
            ),
@@ -530,13 +527,12 @@ class LightRAG:
            embedding_func=None,
        )

        # What is this for? Is it necessary?
        if self.llm_response_cache and hasattr(
            self.llm_response_cache, "global_config"
        ):
            hashing_kv = self.llm_response_cache
        else:
            hashing_kv = self.key_string_value_json_storage_cls(
            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
                namespace=make_namespace(
                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                ),
@@ -545,7 +541,7 @@ class LightRAG:

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                self.llm_model_func,  # type: ignore
                hashing_kv=hashing_kv,
                **self.llm_model_kwargs,
            )
@@ -562,68 +558,45 @@ class LightRAG:
            node_label=nodel_label, max_depth=max_depth
        )

    def _get_storage_class(self, storage_name: str) -> dict:
    def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
        import_path = STORAGES[storage_name]
        storage_class = lazy_external_import(import_path, storage_name)
        return storage_class

    def set_storage_client(self, db_client):
        # Deprecated: set the correct value on LightRAG's *_storage attributes instead
        # Inject db to storage implementation (only tested on Oracle Database)
        for storage in [
            self.vector_db_storage_cls,
            self.graph_storage_cls,
            self.doc_status,
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.key_string_value_json_storage_cls,
            self.chunks_vdb,
            self.relationships_vdb,
            self.entities_vdb,
            self.graph_storage_cls,
            self.chunk_entity_relation_graph,
            self.llm_response_cache,
        ]:
            # set client
            storage.db = db_client

    def insert(
        self,
        string_or_strings: Union[str, list[str]],
        input: str | list[str],
        split_by_character: str | None = None,
        split_by_character_only: bool = False,
    ):
        """Sync insert documents with checkpoint support

        Args:
            string_or_strings: Single document string or list of document strings
            input: Single document string or list of document strings
            split_by_character: if split_by_character is not None, split the string by character; if a chunk is longer than
                chunk_size, split the sub-chunk by token size.
            split_by_character_only: if split_by_character_only is True, split the string by character only; when
                split_by_character is None, this parameter is ignored.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(
            self.ainsert(string_or_strings, split_by_character, split_by_character_only)
            self.ainsert(input, split_by_character, split_by_character_only)
        )

    async def ainsert(
        self,
        string_or_strings: Union[str, list[str]],
        input: str | list[str],
        split_by_character: str | None = None,
        split_by_character_only: bool = False,
    ):
        """Async insert documents with checkpoint support

        Args:
            string_or_strings: Single document string or list of document strings
            input: Single document string or list of document strings
            split_by_character: if split_by_character is not None, split the string by character; if a chunk is longer than
                chunk_size, split the sub-chunk by token size.
            split_by_character_only: if split_by_character_only is True, split the string by character only; when
                split_by_character is None, this parameter is ignored.
        """
        await self.apipeline_enqueue_documents(string_or_strings)
        await self.apipeline_enqueue_documents(input)
        await self.apipeline_process_enqueue_documents(
            split_by_character, split_by_character_only
        )
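With the rename, the public insert API reads naturally for both single documents and batches:

```python
# Both forms are equivalent to enqueue-then-process, per ainsert above
rag.insert("a single document")
rag.insert(["doc one", "doc two"], split_by_character="\n\n")
```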
@@ -680,7 +653,7 @@ class LightRAG:
        if update_storage:
            await self._insert_done()

    async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
    async def apipeline_enqueue_documents(self, input: str | list[str]):
        """
        Pipeline for Processing Documents

@@ -689,11 +662,11 @@ class LightRAG:
        3. Filter out already processed documents
        4. Enqueue document in status
        """
        if isinstance(string_or_strings, str):
            string_or_strings = [string_or_strings]
        if isinstance(input, str):
            input = [input]

        # 1. Remove duplicate contents from the list
        unique_contents = list(set(doc.strip() for doc in string_or_strings))
        unique_contents = list(set(doc.strip() for doc in input))

        # 2. Generate document IDs and initial status
        new_docs: dict[str, Any] = {
@@ -860,32 +833,32 @@ class LightRAG:
            raise e

    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in [  # type: ignore
                self.full_docs,
                self.text_chunks,
                self.llm_response_cache,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.chunk_entity_relation_graph,
            ]
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)
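The refactor replaces the append loop with a single comprehension that filters out unset storages before gathering. The same pattern in isolation:

```python
import asyncio

async def flush_all(storages: list) -> None:
    # Build a callback coroutine for every configured storage, skip None,
    # then await them all concurrently
    await asyncio.gather(
        *[s.index_done_callback() for s in storages if s is not None]
    )
```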

    def insert_custom_kg(self, custom_kg: dict):
    def insert_custom_kg(self, custom_kg: dict[str, Any]):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

    async def ainsert_custom_kg(self, custom_kg: dict):
    async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
        update_storage = False
        try:
            # Insert chunks into vector storage
            all_chunks_data = {}
            chunk_to_source_map = {}
            for chunk_data in custom_kg.get("chunks", []):
            all_chunks_data: dict[str, dict[str, str]] = {}
            chunk_to_source_map: dict[str, str] = {}
            for chunk_data in custom_kg.get("chunks", {}):
                chunk_content = chunk_data["content"]
                source_id = chunk_data["source_id"]
                chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -895,13 +868,13 @@ class LightRAG:
                chunk_to_source_map[source_id] = chunk_id
                update_storage = True

            if self.chunks_vdb is not None and all_chunks_data:
            if all_chunks_data:
                await self.chunks_vdb.upsert(all_chunks_data)
            if self.text_chunks is not None and all_chunks_data:
            if all_chunks_data:
                await self.text_chunks.upsert(all_chunks_data)

            # Insert entities into knowledge graph
            all_entities_data = []
            all_entities_data: list[dict[str, str]] = []
            for entity_data in custom_kg.get("entities", []):
                entity_name = f'"{entity_data["entity_name"].upper()}"'
                entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -917,7 +890,7 @@ class LightRAG:
                )

                # Prepare node data
                node_data = {
                node_data: dict[str, str] = {
                    "entity_type": entity_type,
                    "description": description,
                    "source_id": source_id,
@@ -931,7 +904,7 @@ class LightRAG:
                update_storage = True

            # Insert relationships into knowledge graph
            all_relationships_data = []
            all_relationships_data: list[dict[str, str]] = []
            for relationship_data in custom_kg.get("relationships", []):
                src_id = f'"{relationship_data["src_id"].upper()}"'
                tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -973,7 +946,7 @@ class LightRAG:
                        "source_id": source_id,
                    },
                )
                edge_data = {
                edge_data: dict[str, str] = {
                    "src_id": src_id,
                    "tgt_id": tgt_id,
                    "description": description,
@@ -983,41 +956,68 @@ class LightRAG:
                update_storage = True

            # Insert entities into vector storage if needed
            if self.entities_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                        "content": dp["entity_name"] + dp["description"],
                        "entity_name": dp["entity_name"],
                    }
                    for dp in all_entities_data
                }
                await self.entities_vdb.upsert(data_for_vdb)
            data_for_vdb = {
                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                    "content": dp["entity_name"] + dp["description"],
                    "entity_name": dp["entity_name"],
                }
                for dp in all_entities_data
            }
            await self.entities_vdb.upsert(data_for_vdb)

            # Insert relationships into vector storage if needed
            if self.relationships_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                        "src_id": dp["src_id"],
                        "tgt_id": dp["tgt_id"],
                        "content": dp["keywords"]
                        + dp["src_id"]
                        + dp["tgt_id"]
                        + dp["description"],
                    }
                    for dp in all_relationships_data
                }
                await self.relationships_vdb.upsert(data_for_vdb)
            data_for_vdb = {
                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                    "src_id": dp["src_id"],
                    "tgt_id": dp["tgt_id"],
                    "content": dp["keywords"]
                    + dp["src_id"]
                    + dp["tgt_id"]
                    + dp["description"],
                }
                for dp in all_relationships_data
            }
            await self.relationships_vdb.upsert(data_for_vdb)

        finally:
            if update_storage:
                await self._insert_done()

    def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
    def query(
        self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
    ) -> str | Iterator[str]:
        """
        Perform a sync query.

        Args:
            query (str): The query to be executed.
            param (QueryParam): Configuration parameters for query execution.
            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

        Returns:
            str: The result of the query execution.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, prompt, param))
        return loop.run_until_complete(self.aquery(query, param, prompt))  # type: ignore

    async def aquery(
        self, query: str, prompt: str = "", param: QueryParam = QueryParam()
    ):
        self,
        query: str,
        param: QueryParam = QueryParam(),
        prompt: str | None = None,
    ) -> str | AsyncIterator[str]:
        """
        Perform an async query.

        Args:
            query (str): The query to be executed.
            param (QueryParam): Configuration parameters for query execution.
            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

        Returns:
            str: The result of the query execution.
        """
        if param.mode in ["local", "global", "hybrid"]:
            response = await kg_query(
                query,
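Note the reordered signature: `param` now comes before the optional `prompt`, and the return type is declared as `str | Iterator[str]` (sync) or `str | AsyncIterator[str]` (async) to cover streaming responses. Calling the sync wrapper under the new order:

```python
answer = rag.query(
    "What are the top themes in this corpus?",
    param=QueryParam(mode="hybrid"),
    prompt=None,  # None falls back to PROMPTS["rag_response"]
)
```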
@@ -1097,7 +1097,7 @@ class LightRAG:

    async def aquery_with_separate_keyword_extraction(
        self, query: str, prompt: str, param: QueryParam = QueryParam()
    ):
    ) -> str | AsyncIterator[str]:
        """
        1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
        2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1120,8 +1120,8 @@ class LightRAG:
            ),
        )

        param.hl_keywords = (hl_keywords,)
        param.ll_keywords = (ll_keywords,)
        param.hl_keywords = hl_keywords
        param.ll_keywords = ll_keywords

        # ---------------------
        # STEP 2: Final Query Logic
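The change above also fixes a real bug: the trailing commas previously wrapped each keyword list in a 1-tuple. In miniature:

```python
hl_keywords = ["theme", "topic"]
wrong = (hl_keywords,)   # (["theme", "topic"],) -- a tuple containing the list
right = hl_keywords      # the list itself, as QueryParam expects
assert wrong != right
```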
@@ -1149,7 +1149,7 @@ class LightRAG:
                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
                ),
                global_config=asdict(self),
                embedding_func=self.embedding_funcne,
                embedding_func=self.embedding_func,
            ),
        )
        elif param.mode == "naive":
@@ -1198,12 +1198,7 @@ class LightRAG:
        return response

    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
        await self.llm_response_cache.index_done_callback()

    def delete_by_entity(self, entity_name: str):
        loop = always_get_an_event_loop()
@@ -1225,16 +1220,16 @@ class LightRAG:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        tasks = []
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
        await asyncio.gather(
            *[
                cast(StorageNameSpace, storage_inst).index_done_callback()
                for storage_inst in [  # type: ignore
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.chunk_entity_relation_graph,
                ]
            ]
        )

    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
        """Get summary of document content
@@ -1259,7 +1254,7 @@ class LightRAG:
        """
        return await self.doc_status.get_status_counts()

    async def adelete_by_doc_id(self, doc_id: str):
    async def adelete_by_doc_id(self, doc_id: str) -> None:
        """Delete a document and all its related data

        Args:
@@ -1276,6 +1271,9 @@ class LightRAG:

            # 2. Get all related chunks
            chunks = await self.text_chunks.get_by_id(doc_id)
            if not chunks:
                return

            chunk_ids = list(chunks.keys())
            logger.debug(f"Found {len(chunk_ids)} chunks to delete")

@@ -1446,13 +1444,9 @@ class LightRAG:
        except Exception as e:
            logger.error(f"Error while deleting document {doc_id}: {e}")

    def delete_by_doc_id(self, doc_id: str):
        """Synchronous version of adelete"""
        return asyncio.run(self.adelete_by_doc_id(doc_id))

    async def get_entity_info(
        self, entity_name: str, include_vector_data: bool = False
    ):
    ) -> dict[str, str | None | dict[str, str]]:
        """Get detailed information of an entity

        Args:
@@ -1472,7 +1466,7 @@ class LightRAG:
        node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
        source_id = node_data.get("source_id") if node_data else None

        result = {
        result: dict[str, str | None | dict[str, str]] = {
            "entity_name": entity_name,
            "source_id": source_id,
            "graph_data": node_data,
@@ -1486,21 +1480,6 @@ class LightRAG:

        return result

    def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
        """Synchronous version of getting entity information

        Args:
            entity_name: Entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database
        """
        try:
            import tracemalloc

            tracemalloc.start()
            return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
        finally:
            tracemalloc.stop()

    async def get_relation_info(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ):
@@ -1528,7 +1507,7 @@ class LightRAG:
        )
        source_id = edge_data.get("source_id") if edge_data else None

        result = {
        result: dict[str, str | None | dict[str, str]] = {
            "src_entity": src_entity,
            "tgt_entity": tgt_entity,
            "source_id": source_id,
@@ -1542,23 +1521,3 @@ class LightRAG:
        result["vector_data"] = vector_data[0] if vector_data else None

        return result

    def get_relation_info_sync(
        self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
    ):
        """Synchronous version of getting relationship information

        Args:
            src_entity: Source entity name (no need for quotes)
            tgt_entity: Target entity name (no need for quotes)
            include_vector_data: Whether to include data from the vector database
        """
        try:
            import tracemalloc

            tracemalloc.start()
            return asyncio.run(
                self.get_relation_info(src_entity, tgt_entity, include_vector_data)
            )
        finally:
            tracemalloc.stop()
@@ -1,4 +1,6 @@
from typing import List, Dict, Callable, Any
from __future__ import annotations

from typing import Callable, Any
from pydantic import BaseModel, Field

@@ -23,7 +25,7 @@ class Model(BaseModel):
        ...,
        description="A function that generates the response from the llm. The response must be a string",
    )
    kwargs: Dict[str, Any] = Field(
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
    )
@@ -57,7 +59,7 @@ class MultiModel:
    ```
    """

    def __init__(self, models: List[Model]):
    def __init__(self, models: list[Model]):
        self._models = models
        self._current_model = 0
@@ -66,7 +68,11 @@ class MultiModel:
        return self._models[self._current_model]

    async def llm_model_func(
        self, prompt, system_prompt=None, history_messages=[], **kwargs
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] = [],
        **kwargs: Any,
    ) -> str:
        kwargs.pop("model", None)  # stop from overwriting the custom model name
        kwargs.pop("keyword_extraction", None)
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Iterable

@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import json
import re
from tqdm.asyncio import tqdm as tqdm_async
from typing import Any, Union
from typing import Any, AsyncIterator
from collections import Counter, defaultdict
from .utils import (
    logger,
@@ -36,7 +38,7 @@ import time

def chunking_by_token_size(
    content: str,
    split_by_character: Union[str, None] = None,
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
    overlap_token_size: int = 128,
    max_token_size: int = 1024,
@@ -335,9 +337,9 @@ async def extract_entities(
    knowledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    global_config: dict,
    llm_response_cache: BaseKVStorage = None,
) -> Union[BaseGraphStorage, None]:
    global_config: dict[str, str],
    llm_response_cache: BaseKVStorage | None = None,
) -> BaseGraphStorage | None:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    enable_llm_cache_for_entity_extract: bool = global_config[
@@ -603,15 +605,15 @@ async def extract_entities(


async def kg_query(
    query,
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
    prompt: str = "",
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
    prompt: str | None = None,
) -> str:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
@@ -721,8 +723,8 @@ async def kg_query(
async def extract_keywords_only(
    text: str,
    param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
    """
    Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -818,9 +820,9 @@ async def mix_kg_vector_query(
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Hybrid retrieval implementation combining knowledge graph and vector search.

@@ -1539,13 +1541,13 @@ def combine_contexts(entities, relationships, sources):


async def naive_query(
    query,
    query: str,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
):
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1646,9 +1648,9 @@ async def kg_query_with_keywords(
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Refactored kg_query that does NOT extract keywords by itself.
    It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
@@ -1,3 +1,5 @@
from __future__ import annotations

GRAPH_FIELD_SEP = "<SEP>"

PROMPTS = {}

@@ -1,16 +1,18 @@
from __future__ import annotations

from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from typing import Any, Optional


class GPTKeywordExtractionFormat(BaseModel):
    high_level_keywords: List[str]
    low_level_keywords: List[str]
    high_level_keywords: list[str]
    low_level_keywords: list[str]


class KnowledgeGraphNode(BaseModel):
    id: str
    labels: List[str]
    properties: Dict[str, Any]  # anything else goes here
    labels: list[str]
    properties: dict[str, Any]  # anything else goes here


class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +20,9 @@ class KnowledgeGraphEdge(BaseModel):
    type: Optional[str]
    source: str  # id of source node
    target: str  # id of target node
    properties: Dict[str, Any]  # anything else goes here
    properties: dict[str, Any]  # anything else goes here


class KnowledgeGraph(BaseModel):
    nodes: List[KnowledgeGraphNode] = []
    edges: List[KnowledgeGraphEdge] = []
    nodes: list[KnowledgeGraphNode] = []
    edges: list[KnowledgeGraphEdge] = []
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import html
import io
@@ -9,7 +11,7 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List, Optional
from typing import Any, Callable
import xml.etree.ElementTree as ET
import bs4

@@ -67,12 +69,12 @@ class EmbeddingFunc:

@dataclass
class ReasoningResponse:
    reasoning_content: str
    reasoning_content: str | None
    response_content: str
    tag: str


def locate_json_string_body_from_string(content: str) -> Union[str, None]:
def locate_json_string_body_from_string(content: str) -> str | None:
    """Locate the JSON string body from a string"""
    try:
        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
        raise e from None


def compute_args_hash(*args, cache_type: str = None) -> str:
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
    """Compute a hash for the given arguments.
    Args:
        *args: Arguments to hash
@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
    return hashlib.md5(args_str.encode()).hexdigest()


def compute_mdhash_id(content, prefix: str = ""):
def compute_mdhash_id(content: str, prefix: str = "") -> str:
    """
    Compute a unique ID for a given content string.

    The ID is a combination of the given prefix and the MD5 hash of the content string.
    """
    return prefix + md5(content.encode()).hexdigest()
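These prefixed MD5 IDs are what the insert paths use for chunk and entity keys, for example:

```python
chunk_id = compute_mdhash_id("some chunk text".strip(), prefix="chunk-")
entity_id = compute_mdhash_id('"ACME CORP"', prefix="ent-")
```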

@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


def is_float_regex(value):
def is_float_regex(value: str) -> bool:
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
def truncate_list_by_token_size(
    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[int]:
    """Truncate a list of data by token size"""
    if max_token_size <= 0:
        return []
@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
    return list_data


def list_of_list_to_csv(data: List[List[str]]) -> str:
def list_of_list_to_csv(data: list[list[str]]) -> str:
    output = io.StringIO()
    writer = csv.writer(
        output,
@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
    return output.getvalue()


def csv_string_to_list(csv_string: str) -> List[List[str]]:
def csv_string_to_list(csv_string: str) -> list[list[str]]:
    # Clean the string by removing NUL characters
    cleaned_string = csv_string.replace("\0", "")

@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
    return None


def process_combine_contexts(hl, ll):
def process_combine_contexts(hl: str, ll: str):
    header = None
    list_hl = csv_string_to_list(hl.strip())
    list_ll = csv_string_to_list(ll.strip())
@@ -375,7 +384,7 @@ async def get_best_cached_response(
    llm_func=None,
    original_prompt=None,
    cache_type=None,
) -> Union[str, None]:
) -> str | None:
    logger.debug(
        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
    )
@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
    return dot_product / (norm1 * norm2)


def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    """Quantize embedding to specified bits"""
    # Convert list to numpy array if needed
    if isinstance(embedding, list):
@@ -570,9 +579,9 @@ class CacheData:
    args_hash: str
    content: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    quantized: np.ndarray | None = None
    min_val: float | None = None
    max_val: float | None = None
    mode: str = "default"
    cache_type: str = "query"

@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
    return False


def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
def get_conversation_turns(
    conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
    """
    Process conversation history to get the specified number of complete turns.

@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
        Formatted string of the conversation history
    """
    # Group messages into turns
    turns = []
    messages = []
    turns: list[list[dict[str, Any]]] = []
    messages: list[dict[str, Any]] = []

    # First, filter out keyword extraction messages
    for msg in conversation_history:
@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
    turns = turns[-num_turns:]

    # Format the turns into a string
    formatted_turns = []
    formatted_turns: list[str] = []
    for turn in turns:
        formatted_turns.extend(
            [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]