diff --git a/README.md b/README.md
index 62f21a65..487c65f5 100644
--- a/README.md
+++ b/README.md
@@ -428,9 +428,9 @@ And using a routine to process news documents.
```python
rag = LightRAG(..)
-await rag.apipeline_enqueue_documents(string_or_strings)
+await rag.apipeline_enqueue_documents(input)
# Your routine in loop
-await rag.apipeline_process_enqueue_documents(string_or_strings)
+await rag.apipeline_process_enqueue_documents(input)
```
### Separate Keyword Extraction
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index f5269fae..9c90424e 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -113,7 +113,24 @@ async def main():
)
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
- rag.set_storage_client(db_client=oracle_db)
+
+ for storage in [
+ rag.vector_db_storage_cls,
+ rag.graph_storage_cls,
+ rag.doc_status,
+ rag.full_docs,
+ rag.text_chunks,
+ rag.llm_response_cache,
+ rag.key_string_value_json_storage_cls,
+ rag.chunks_vdb,
+ rag.relationships_vdb,
+ rag.entities_vdb,
+ rag.graph_storage_cls,
+ rag.chunk_entity_relation_graph,
+ rag.llm_response_cache,
+ ]:
+ # set client
+ storage.db = oracle_db
# Extract and Insert into LightRAG storage
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
diff --git a/external_bindings/OpenWebuiTool/openwebui_tool.py b/external_bindings/OpenWebuiTool/openwebui_tool.py
deleted file mode 100644
index 8df3109c..00000000
--- a/external_bindings/OpenWebuiTool/openwebui_tool.py
+++ /dev/null
@@ -1,358 +0,0 @@
-"""
-OpenWebui Lightrag Integration Tool
-==================================
-
-This tool enables the integration and use of Lightrag within the OpenWebui environment,
-providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
-
-Author: ParisNeo (parisneoai@gmail.com)
-Social:
- - Twitter: @ParisNeo_AI
- - Reddit: r/lollms
- - Instagram: https://www.instagram.com/parisneo_ai/
-
-License: Apache 2.0
-Copyright (c) 2024-2025 ParisNeo
-
-This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
-For more information, visit: https://github.com/ParisNeo/lollms
-
-Requirements:
- - Python 3.8+
- - OpenWebui
- - Lightrag
-"""
-
-# Tool version
-__version__ = "1.0.0"
-__author__ = "ParisNeo"
-__author_email__ = "parisneoai@gmail.com"
-__description__ = "Lightrag integration for OpenWebui"
-
-
-import requests
-import json
-from pydantic import BaseModel, Field
-from typing import Callable, Any, Literal, Union, List, Tuple
-
-
-class StatusEventEmitter:
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
- self.event_emitter = event_emitter
-
- async def emit(self, description="Unknown State", status="in_progress", done=False):
- if self.event_emitter:
- await self.event_emitter(
- {
- "type": "status",
- "data": {
- "status": status,
- "description": description,
- "done": done,
- },
- }
- )
-
-
-class MessageEventEmitter:
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
- self.event_emitter = event_emitter
-
- async def emit(self, content="Some message"):
- if self.event_emitter:
- await self.event_emitter(
- {
- "type": "message",
- "data": {
- "content": content,
- },
- }
- )
-
-
-class Tools:
- class Valves(BaseModel):
- LIGHTRAG_SERVER_URL: str = Field(
- default="http://localhost:9621/query",
- description="The base URL for the LightRag server",
- )
- MODE: Literal["naive", "local", "global", "hybrid"] = Field(
- default="hybrid",
- description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
- )
- ONLY_NEED_CONTEXT: bool = Field(
- default=False,
- description="If True, only the context is needed from the LightRag response",
- )
- DEBUG_MODE: bool = Field(
- default=False,
- description="If True, debugging information will be emitted",
- )
- KEY: str = Field(
- default="",
- description="Optional Bearer Key for authentication",
- )
- MAX_ENTITIES: int = Field(
- default=5,
- description="Maximum number of entities to keep",
- )
- MAX_RELATIONSHIPS: int = Field(
- default=5,
- description="Maximum number of relationships to keep",
- )
- MAX_SOURCES: int = Field(
- default=3,
- description="Maximum number of sources to keep",
- )
-
- def __init__(self):
- self.valves = self.Valves()
- self.headers = {
- "Content-Type": "application/json",
- "User-Agent": "LightRag-Tool/1.0",
- }
-
- async def query_lightrag(
- self,
- query: str,
- __event_emitter__: Callable[[dict], Any] = None,
- ) -> str:
- """
- Query the LightRag server and retrieve information.
- This function must be called before answering the user question
- :params query: The query string to send to the LightRag server.
- :return: The response from the LightRag server in Markdown format or raw response.
- """
- self.status_emitter = StatusEventEmitter(__event_emitter__)
- self.message_emitter = MessageEventEmitter(__event_emitter__)
-
- lightrag_url = self.valves.LIGHTRAG_SERVER_URL
- payload = {
- "query": query,
- "mode": str(self.valves.MODE),
- "stream": False,
- "only_need_context": self.valves.ONLY_NEED_CONTEXT,
- }
- await self.status_emitter.emit("Initializing Lightrag query..")
-
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "### Debug Mode Active\n\nDebugging information will be displayed.\n"
- )
- await self.message_emitter.emit(
- "#### Payload Sent to LightRag Server\n```json\n"
- + json.dumps(payload, indent=4)
- + "\n```\n"
- )
-
- # Add Bearer Key to headers if provided
- if self.valves.KEY:
- self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
-
- try:
- await self.status_emitter.emit("Sending request to LightRag server")
-
- response = requests.post(
- lightrag_url, json=payload, headers=self.headers, timeout=120
- )
- response.raise_for_status()
- data = response.json()
- await self.status_emitter.emit(
- status="complete",
- description="LightRag query Succeeded",
- done=True,
- )
-
- # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
- if self.valves.ONLY_NEED_CONTEXT:
- try:
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "#### LightRag Server Response\n```json\n"
- + data["response"]
- + "\n```\n"
- )
- except Exception as ex:
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "#### Exception\n" + str(ex) + "\n"
- )
- return f"Exception: {ex}"
- return data["response"]
- else:
- if self.valves.DEBUG_MODE:
- await self.message_emitter.emit(
- "#### LightRag Server Response\n```json\n"
- + data["response"]
- + "\n```\n"
- )
- await self.status_emitter.emit("Lightrag query success")
- return data["response"]
-
- except requests.exceptions.RequestException as e:
- await self.status_emitter.emit(
- status="error",
- description=f"Error during LightRag query: {str(e)}",
- done=True,
- )
- return json.dumps({"error": str(e)})
-
- def extract_code_blocks(
- self, text: str, return_remaining_text: bool = False
- ) -> Union[List[dict], Tuple[List[dict], str]]:
- """
- This function extracts code blocks from a given text and optionally returns the text without code blocks.
-
- Parameters:
- text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
- return_remaining_text (bool): If True, also returns the text with code blocks removed.
-
- Returns:
- Union[List[dict], Tuple[List[dict], str]]:
- - If return_remaining_text is False: Returns only the list of code block dictionaries
- - If return_remaining_text is True: Returns a tuple containing:
- * List of code block dictionaries
- * String containing the text with all code blocks removed
-
- Each code block dictionary contains:
- - 'index' (int): The index of the code block in the text
- - 'file_name' (str): The name of the file extracted from the preceding line, if available
- - 'content' (str): The content of the code block
- - 'type' (str): The type of the code block
- - 'is_complete' (bool): True if the block has a closing tag, False otherwise
- """
- remaining = text
- bloc_index = 0
- first_index = 0
- indices = []
- text_without_blocks = text
-
- # Find all code block delimiters
- while len(remaining) > 0:
- try:
- index = remaining.index("```")
- indices.append(index + first_index)
- remaining = remaining[index + 3 :]
- first_index += index + 3
- bloc_index += 1
- except Exception:
- if bloc_index % 2 == 1:
- index = len(remaining)
- indices.append(index)
- remaining = ""
-
- code_blocks = []
- is_start = True
-
- # Process code blocks and build text without blocks if requested
- if return_remaining_text:
- text_parts = []
- last_end = 0
-
- for index, code_delimiter_position in enumerate(indices):
- if is_start:
- block_infos = {
- "index": len(code_blocks),
- "file_name": "",
- "section": "",
- "content": "",
- "type": "",
- "is_complete": False,
- }
-
- # Store text before code block if returning remaining text
- if return_remaining_text:
- text_parts.append(text[last_end:code_delimiter_position].strip())
-
- # Check the preceding line for file name
- preceding_text = text[:code_delimiter_position].strip().splitlines()
- if preceding_text:
- last_line = preceding_text[-1].strip()
- if last_line.startswith("") and last_line.endswith(
- ""
- ):
- file_name = last_line[
- len("") : -len("")
- ].strip()
- block_infos["file_name"] = file_name
- elif last_line.startswith("## filename:"):
- file_name = last_line[len("## filename:") :].strip()
- block_infos["file_name"] = file_name
- if last_line.startswith("") and last_line.endswith(
- ""
- ):
- section = last_line[
- len("")
- ].strip()
- block_infos["section"] = section
-
- sub_text = text[code_delimiter_position + 3 :]
- if len(sub_text) > 0:
- try:
- find_space = sub_text.index(" ")
- except Exception:
- find_space = int(1e10)
- try:
- find_return = sub_text.index("\n")
- except Exception:
- find_return = int(1e10)
- next_index = min(find_return, find_space)
- if "{" in sub_text[:next_index]:
- next_index = 0
- start_pos = next_index
-
- if code_delimiter_position + 3 < len(text) and text[
- code_delimiter_position + 3
- ] in ["\n", " ", "\t"]:
- block_infos["type"] = "language-specific"
- else:
- block_infos["type"] = sub_text[:next_index]
-
- if index + 1 < len(indices):
- next_pos = indices[index + 1] - code_delimiter_position
- if (
- next_pos - 3 < len(sub_text)
- and sub_text[next_pos - 3] == "`"
- ):
- block_infos["content"] = sub_text[
- start_pos : next_pos - 3
- ].strip()
- block_infos["is_complete"] = True
- else:
- block_infos["content"] = sub_text[
- start_pos:next_pos
- ].strip()
- block_infos["is_complete"] = False
-
- if return_remaining_text:
- last_end = indices[index + 1] + 3
- else:
- block_infos["content"] = sub_text[start_pos:].strip()
- block_infos["is_complete"] = False
-
- if return_remaining_text:
- last_end = len(text)
-
- code_blocks.append(block_infos)
- is_start = False
- else:
- is_start = True
-
- if return_remaining_text:
- # Add any remaining text after the last code block
- if last_end < len(text):
- text_parts.append(text[last_end:].strip())
- # Join all non-code parts with newlines
- text_without_blocks = "\n".join(filter(None, text_parts))
- return code_blocks, text_without_blocks
-
- return code_blocks
-
- def clean(self, csv_content: str):
- lines = csv_content.splitlines()
- if lines:
- # Remove spaces around headers and ensure no spaces between commas
- header = ",".join([col.strip() for col in lines[0].split(",")])
- lines[0] = header # Replace the first line with the cleaned header
- csv_content = "\n".join(lines)
- return csv_content
diff --git a/lightrag/base.py b/lightrag/base.py
index e75167c4..8e6a212d 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -83,11 +83,11 @@ class StorageNameSpace:
namespace: str
global_config: dict[str, Any]
- async def index_done_callback(self):
+ async def index_done_callback(self) -> None:
"""Commit the storage operations after indexing"""
pass
- async def query_done_callback(self):
+ async def query_done_callback(self) -> None:
"""Commit the storage operations after querying"""
pass
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index fcea2c57..b4426cd7 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -6,7 +6,7 @@ import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
-from typing import Any, Callable, Optional, Type, Union, cast
+from typing import Any, Callable, Optional, Union, cast
from .base import (
BaseGraphStorage,
@@ -304,7 +304,7 @@ class LightRAG:
- random_seed: Seed value for reproducibility.
"""
- embedding_func: Union[EmbeddingFunc, None] = None
+ embedding_func: EmbeddingFunc | None = None
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = 32
@@ -344,10 +344,8 @@ class LightRAG:
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
- """Dictionary for additional parameters and extensions."""
- # extension
- addon_params: dict[str, Any] = field(default_factory=dict)
+ """Dictionary for additional parameters and extensions."""
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
)
@@ -445,77 +443,74 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
- self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
# Initialize all storages
- self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( # type: ignore
+ self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
- )
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( # type: ignore
+ ) # type: ignore
+ self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
- )
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( # type: ignore
+ ) # type: ignore
+ self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
- )
-
- self.key_string_value_json_storage_cls = partial( # type: ignore
+ ) # type: ignore
+ self.key_string_value_json_storage_cls = partial( # type: ignore
self.key_string_value_json_storage_cls, global_config=global_config
)
-
- self.vector_db_storage_cls = partial( # type: ignore
+ self.vector_db_storage_cls = partial( # type: ignore
self.vector_db_storage_cls, global_config=global_config
)
-
- self.graph_storage_cls = partial( # type: ignore
+ self.graph_storage_cls = partial( # type: ignore
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
- self.llm_response_cache = self.key_string_value_json_storage_cls( # type: ignore
+ self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
- self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
+ self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
- self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
+ self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
- self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
+ self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
- self.entities_vdb = self.vector_db_storage_cls( # type: ignore
+ self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
- self.relationships_vdb = self.vector_db_storage_cls( # type: ignore
+ self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
- self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
+ self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
@@ -535,16 +530,16 @@ class LightRAG:
):
hashing_kv = self.llm_response_cache
else:
- hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
+ hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
-
+
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
- self.llm_model_func, # type: ignore
+ self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
@@ -836,32 +831,32 @@ class LightRAG:
raise e
async def _insert_done(self):
- tasks = []
- for storage_inst in [
- self.full_docs,
- self.text_chunks,
- self.llm_response_cache,
- self.entities_vdb,
- self.relationships_vdb,
- self.chunks_vdb,
- self.chunk_entity_relation_graph,
- ]:
- if storage_inst is None:
- continue
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
+ tasks = [
+ cast(StorageNameSpace, storage_inst).index_done_callback()
+ for storage_inst in [ # type: ignore
+ self.full_docs,
+ self.text_chunks,
+ self.llm_response_cache,
+ self.entities_vdb,
+ self.relationships_vdb,
+ self.chunks_vdb,
+ self.chunk_entity_relation_graph,
+ ]
+ if storage_inst is not None
+ ]
await asyncio.gather(*tasks)
- def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
+ def insert_custom_kg(self, custom_kg: dict[str, Any]):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
- async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
+ async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
update_storage = False
try:
# Insert chunks into vector storage
- all_chunks_data = {}
- chunk_to_source_map = {}
- for chunk_data in custom_kg.get("chunks", []):
+ all_chunks_data: dict[str, dict[str, str]] = {}
+ chunk_to_source_map: dict[str, str] = {}
+ for chunk_data in custom_kg.get("chunks", {}):
chunk_content = chunk_data["content"]
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -871,13 +866,13 @@ class LightRAG:
chunk_to_source_map[source_id] = chunk_id
update_storage = True
- if self.chunks_vdb is not None and all_chunks_data:
+ if all_chunks_data:
await self.chunks_vdb.upsert(all_chunks_data)
- if self.text_chunks is not None and all_chunks_data:
+ if all_chunks_data:
await self.text_chunks.upsert(all_chunks_data)
# Insert entities into knowledge graph
- all_entities_data = []
+ all_entities_data: list[dict[str, str]] = []
for entity_data in custom_kg.get("entities", []):
entity_name = f'"{entity_data["entity_name"].upper()}"'
entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -893,7 +888,7 @@ class LightRAG:
)
# Prepare node data
- node_data = {
+ node_data: dict[str, str] = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
@@ -907,7 +902,7 @@ class LightRAG:
update_storage = True
# Insert relationships into knowledge graph
- all_relationships_data = []
+ all_relationships_data: list[dict[str, str]] = []
for relationship_data in custom_kg.get("relationships", []):
src_id = f'"{relationship_data["src_id"].upper()}"'
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -949,7 +944,7 @@ class LightRAG:
"source_id": source_id,
},
)
- edge_data = {
+ edge_data: dict[str, str] = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
@@ -959,19 +954,17 @@ class LightRAG:
update_storage = True
# Insert entities into vector storage if needed
- if self.entities_vdb is not None:
- data_for_vdb = {
+ data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
- await self.entities_vdb.upsert(data_for_vdb)
+ await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
- if self.relationships_vdb is not None:
- data_for_vdb = {
+ data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
@@ -982,18 +975,49 @@ class LightRAG:
}
for dp in all_relationships_data
}
- await self.relationships_vdb.upsert(data_for_vdb)
+ await self.relationships_vdb.upsert(data_for_vdb)
+
finally:
if update_storage:
await self._insert_done()
- def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
+ def query(
+ self,
+ query: str,
+ param: QueryParam = QueryParam(),
+ prompt: str | None = None
+ ) -> str:
+ """
+ Perform a sync query.
+
+ Args:
+ query (str): The query to be executed.
+ param (QueryParam): Configuration parameters for query execution.
+ prompt (Optional[str]): Custom prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+ Returns:
+ str: The result of the query execution.
+ """
loop = always_get_an_event_loop()
- return loop.run_until_complete(self.aquery(query, prompt, param))
+ return loop.run_until_complete(self.aquery(query, param, prompt))
async def aquery(
- self, query: str, prompt: str = "", param: QueryParam = QueryParam()
- ):
+ self,
+ query: str,
+ param: QueryParam = QueryParam(),
+ prompt: str | None = None,
+ ) -> str:
+ """
+ Perform an async query.
+
+ Args:
+ query (str): The query to be executed.
+ param (QueryParam): Configuration parameters for query execution.
+ prompt (Optional[str]): Custom prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+
+ Returns:
+ str: The result of the query execution.
+ """
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 04aad0d4..a961cfd9 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -295,8 +295,8 @@ async def extract_entities(
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
- global_config: dict,
- llm_response_cache: BaseKVStorage = None,
+ global_config: dict[str, str],
+ llm_response_cache: BaseKVStorage | None = None,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -563,15 +563,15 @@ async def extract_entities(
async def kg_query(
- query,
+ query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
- prompt: str = "",
+ global_config: dict[str, str],
+ hashing_kv: BaseKVStorage | None = None,
+ prompt: str | None = None,
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]
@@ -681,8 +681,8 @@ async def kg_query(
async def extract_keywords_only(
text: str,
param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
+ global_config: dict[str, str],
+ hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
"""
Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -778,8 +778,8 @@ async def mix_kg_vector_query(
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
+ global_config: dict[str, str],
+ hashing_kv: BaseKVStorage | None = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1499,12 +1499,12 @@ def combine_contexts(entities, relationships, sources):
async def naive_query(
- query,
+ query: str,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
+ global_config: dict[str, str],
+ hashing_kv: BaseKVStorage | None = None,
):
# Handle cache
use_model_func = global_config["llm_model_func"]
@@ -1606,8 +1606,8 @@ async def kg_query_with_keywords(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict,
- hashing_kv: BaseKVStorage = None,
+ global_config: dict[str, str],
+ hashing_kv: BaseKVStorage | None = None,
) -> str:
"""
Refactored kg_query that does NOT extract keywords by itself.
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9df325ca..c94e23cb 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -128,7 +128,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
return hashlib.md5(args_str.encode()).hexdigest()
-def compute_mdhash_id(content, prefix: str = ""):
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+ """
+ Compute a deterministic ID for a given content string.
+
+ The ID is the given prefix followed by the hexadecimal MD5 digest of the content string.
+ """
return prefix + md5(content.encode()).hexdigest()