Merge pull request #782 from YanSte/code-cleaning

Code Cleanup & Maintenance Improvements
zrguo
2025-02-16 19:46:42 +08:00
committed by GitHub
16 changed files with 299 additions and 619 deletions

View File

@@ -428,9 +428,9 @@ And using a routine to process news documents.
```python
rag = LightRAG(..)
await rag.apipeline_enqueue_documents(string_or_strings)
await rag.apipeline_enqueue_documents(input)
# Your routine in loop
await rag.apipeline_process_enqueue_documents(string_or_strings)
await rag.apipeline_process_enqueue_documents(input)
```
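Below is a minimal sketch of how such a routine could be wired up (the scheduling loop and the no-argument form of the processing call are assumptions for illustration; the processing call also accepts the same splitting options as `ainsert`):
```python
import asyncio

rag = LightRAG(..)  # configured as above

async def enqueue_news(docs: list[str]):
    # Producer: enqueue each incoming batch of news documents
    await rag.apipeline_enqueue_documents(docs)

async def background_indexer():
    # Consumer: periodically process whatever has been enqueued
    while True:
        await rag.apipeline_process_enqueue_documents()
        await asyncio.sleep(60)
```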
### Separate Keyword Extraction

View File

@@ -113,7 +113,24 @@ async def main():
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.set_storage_client(db_client=oracle_db)
for storage in [
rag.vector_db_storage_cls,
rag.graph_storage_cls,
rag.doc_status,
rag.full_docs,
rag.text_chunks,
rag.llm_response_cache,
rag.key_string_value_json_storage_cls,
rag.chunks_vdb,
rag.relationships_vdb,
rag.entities_vdb,
rag.graph_storage_cls,
rag.chunk_entity_relation_graph,
rag.llm_response_cache,
]:
# set client
storage.db = oracle_db
# Extract and Insert into LightRAG storage
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:

View File

@@ -17,7 +17,9 @@ if not os.path.exists(WORKING_DIR):
# ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
CHROMADB_LOCAL_PATH = os.environ.get(
"CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
)
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
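For reference, a minimal sketch of how a client could be built from these settings, using chromadb's standard PersistentClient/HttpClient constructors (variable names follow this example file):
```python
import chromadb

if CHROMADB_USE_LOCAL_PERSISTENT:
    chroma_client = chromadb.PersistentClient(path=CHROMADB_LOCAL_PATH)
else:
    chroma_client = chromadb.HttpClient(host=CHROMADB_HOST, port=CHROMADB_PORT)
```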

View File

@@ -1,358 +0,0 @@
"""
OpenWebui Lightrag Integration Tool
==================================
This tool enables the integration and use of Lightrag within the OpenWebui environment,
providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
Author: ParisNeo (parisneoai@gmail.com)
Social:
- Twitter: @ParisNeo_AI
- Reddit: r/lollms
- Instagram: https://www.instagram.com/parisneo_ai/
License: Apache 2.0
Copyright (c) 2024-2025 ParisNeo
This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
For more information, visit: https://github.com/ParisNeo/lollms
Requirements:
- Python 3.8+
- OpenWebui
- Lightrag
"""
# Tool version
__version__ = "1.0.0"
__author__ = "ParisNeo"
__author_email__ = "parisneoai@gmail.com"
__description__ = "Lightrag integration for OpenWebui"
import requests
import json
from pydantic import BaseModel, Field
from typing import Callable, Any, Literal, Union, List, Tuple
class StatusEventEmitter:
def __init__(self, event_emitter: Callable[[dict], Any] = None):
self.event_emitter = event_emitter
async def emit(self, description="Unknown State", status="in_progress", done=False):
if self.event_emitter:
await self.event_emitter(
{
"type": "status",
"data": {
"status": status,
"description": description,
"done": done,
},
}
)
class MessageEventEmitter:
def __init__(self, event_emitter: Callable[[dict], Any] = None):
self.event_emitter = event_emitter
async def emit(self, content="Some message"):
if self.event_emitter:
await self.event_emitter(
{
"type": "message",
"data": {
"content": content,
},
}
)
class Tools:
class Valves(BaseModel):
LIGHTRAG_SERVER_URL: str = Field(
default="http://localhost:9621/query",
description="The base URL for the LightRag server",
)
MODE: Literal["naive", "local", "global", "hybrid"] = Field(
default="hybrid",
description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
)
ONLY_NEED_CONTEXT: bool = Field(
default=False,
description="If True, only the context is needed from the LightRag response",
)
DEBUG_MODE: bool = Field(
default=False,
description="If True, debugging information will be emitted",
)
KEY: str = Field(
default="",
description="Optional Bearer Key for authentication",
)
MAX_ENTITIES: int = Field(
default=5,
description="Maximum number of entities to keep",
)
MAX_RELATIONSHIPS: int = Field(
default=5,
description="Maximum number of relationships to keep",
)
MAX_SOURCES: int = Field(
default=3,
description="Maximum number of sources to keep",
)
def __init__(self):
self.valves = self.Valves()
self.headers = {
"Content-Type": "application/json",
"User-Agent": "LightRag-Tool/1.0",
}
async def query_lightrag(
self,
query: str,
__event_emitter__: Callable[[dict], Any] = None,
) -> str:
"""
Query the LightRag server and retrieve information.
This function must be called before answering the user's question.
:param query: The query string to send to the LightRag server.
:return: The response from the LightRag server in Markdown format or raw response.
"""
self.status_emitter = StatusEventEmitter(__event_emitter__)
self.message_emitter = MessageEventEmitter(__event_emitter__)
lightrag_url = self.valves.LIGHTRAG_SERVER_URL
payload = {
"query": query,
"mode": str(self.valves.MODE),
"stream": False,
"only_need_context": self.valves.ONLY_NEED_CONTEXT,
}
await self.status_emitter.emit("Initializing Lightrag query..")
if self.valves.DEBUG_MODE:
await self.message_emitter.emit(
"### Debug Mode Active\n\nDebugging information will be displayed.\n"
)
await self.message_emitter.emit(
"#### Payload Sent to LightRag Server\n```json\n"
+ json.dumps(payload, indent=4)
+ "\n```\n"
)
# Add Bearer Key to headers if provided
if self.valves.KEY:
self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
try:
await self.status_emitter.emit("Sending request to LightRag server")
response = requests.post(
lightrag_url, json=payload, headers=self.headers, timeout=120
)
response.raise_for_status()
data = response.json()
await self.status_emitter.emit(
status="complete",
description="LightRag query Succeeded",
done=True,
)
# Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
if self.valves.ONLY_NEED_CONTEXT:
try:
if self.valves.DEBUG_MODE:
await self.message_emitter.emit(
"#### LightRag Server Response\n```json\n"
+ data["response"]
+ "\n```\n"
)
except Exception as ex:
if self.valves.DEBUG_MODE:
await self.message_emitter.emit(
"#### Exception\n" + str(ex) + "\n"
)
return f"Exception: {ex}"
return data["response"]
else:
if self.valves.DEBUG_MODE:
await self.message_emitter.emit(
"#### LightRag Server Response\n```json\n"
+ data["response"]
+ "\n```\n"
)
await self.status_emitter.emit("Lightrag query success")
return data["response"]
except requests.exceptions.RequestException as e:
await self.status_emitter.emit(
status="error",
description=f"Error during LightRag query: {str(e)}",
done=True,
)
return json.dumps({"error": str(e)})
def extract_code_blocks(
self, text: str, return_remaining_text: bool = False
) -> Union[List[dict], Tuple[List[dict], str]]:
"""
This function extracts code blocks from a given text and optionally returns the text without code blocks.
Parameters:
text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
return_remaining_text (bool): If True, also returns the text with code blocks removed.
Returns:
Union[List[dict], Tuple[List[dict], str]]:
- If return_remaining_text is False: Returns only the list of code block dictionaries
- If return_remaining_text is True: Returns a tuple containing:
* List of code block dictionaries
* String containing the text with all code blocks removed
Each code block dictionary contains:
- 'index' (int): The index of the code block in the text
- 'file_name' (str): The name of the file extracted from the preceding line, if available
- 'content' (str): The content of the code block
- 'type' (str): The type of the code block
- 'is_complete' (bool): True if the block has a closing tag, False otherwise
"""
remaining = text
bloc_index = 0
first_index = 0
indices = []
text_without_blocks = text
# Find all code block delimiters
while len(remaining) > 0:
try:
index = remaining.index("```")
indices.append(index + first_index)
remaining = remaining[index + 3 :]
first_index += index + 3
bloc_index += 1
except Exception:
if bloc_index % 2 == 1:
index = len(remaining)
indices.append(index)
remaining = ""
code_blocks = []
is_start = True
# Process code blocks and build text without blocks if requested
if return_remaining_text:
text_parts = []
last_end = 0
for index, code_delimiter_position in enumerate(indices):
if is_start:
block_infos = {
"index": len(code_blocks),
"file_name": "",
"section": "",
"content": "",
"type": "",
"is_complete": False,
}
# Store text before code block if returning remaining text
if return_remaining_text:
text_parts.append(text[last_end:code_delimiter_position].strip())
# Check the preceding line for file name
preceding_text = text[:code_delimiter_position].strip().splitlines()
if preceding_text:
last_line = preceding_text[-1].strip()
if last_line.startswith("<file_name>") and last_line.endswith(
"</file_name>"
):
file_name = last_line[
len("<file_name>") : -len("</file_name>")
].strip()
block_infos["file_name"] = file_name
elif last_line.startswith("## filename:"):
file_name = last_line[len("## filename:") :].strip()
block_infos["file_name"] = file_name
if last_line.startswith("<section>") and last_line.endswith(
"</section>"
):
section = last_line[
len("<section>") : -len("</section>")
].strip()
block_infos["section"] = section
sub_text = text[code_delimiter_position + 3 :]
if len(sub_text) > 0:
try:
find_space = sub_text.index(" ")
except Exception:
find_space = int(1e10)
try:
find_return = sub_text.index("\n")
except Exception:
find_return = int(1e10)
next_index = min(find_return, find_space)
if "{" in sub_text[:next_index]:
next_index = 0
start_pos = next_index
if code_delimiter_position + 3 < len(text) and text[
code_delimiter_position + 3
] in ["\n", " ", "\t"]:
block_infos["type"] = "language-specific"
else:
block_infos["type"] = sub_text[:next_index]
if index + 1 < len(indices):
next_pos = indices[index + 1] - code_delimiter_position
if (
next_pos - 3 < len(sub_text)
and sub_text[next_pos - 3] == "`"
):
block_infos["content"] = sub_text[
start_pos : next_pos - 3
].strip()
block_infos["is_complete"] = True
else:
block_infos["content"] = sub_text[
start_pos:next_pos
].strip()
block_infos["is_complete"] = False
if return_remaining_text:
last_end = indices[index + 1] + 3
else:
block_infos["content"] = sub_text[start_pos:].strip()
block_infos["is_complete"] = False
if return_remaining_text:
last_end = len(text)
code_blocks.append(block_infos)
is_start = False
else:
is_start = True
if return_remaining_text:
# Add any remaining text after the last code block
if last_end < len(text):
text_parts.append(text[last_end:].strip())
# Join all non-code parts with newlines
text_without_blocks = "\n".join(filter(None, text_parts))
return code_blocks, text_without_blocks
return code_blocks
def clean(self, csv_content: str):
lines = csv_content.splitlines()
if lines:
# Remove spaces around headers and ensure no spaces between commas
header = ",".join([col.strip() for col in lines[0].split(",")])
lines[0] = header # Replace the first line with the cleaned header
csv_content = "\n".join(lines)
return csv_content
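The removed tool is essentially a thin wrapper around one HTTP request to the LightRAG server; a standalone sketch of that request, reusing the endpoint, payload keys, and headers from the deleted code, looks like this:
```python
import requests

payload = {
    "query": "What is LightRAG?",
    "mode": "hybrid",  # naive | local | global | hybrid
    "stream": False,
    "only_need_context": False,
}
headers = {"Content-Type": "application/json", "User-Agent": "LightRag-Tool/1.0"}
# headers["Authorization"] = f"Bearer {key}"  # only if a bearer key is configured

response = requests.post(
    "http://localhost:9621/query", json=payload, headers=headers, timeout=120
)
response.raise_for_status()
print(response.json()["response"])
```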

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import (
Any,
Literal,
Optional,
TypedDict,
TypeVar,
Union,
)
import numpy as np
@@ -69,7 +69,7 @@ class QueryParam:
ll_keywords: list[str] = field(default_factory=list)
"""List of low-level keywords to refine retrieval focus."""
conversation_history: list[dict[str, Any]] = field(default_factory=list)
conversation_history: list[dict[str, str]] = field(default_factory=list)
"""Stores past conversation history to maintain context.
Format: [{"role": "user/assistant", "content": "message"}].
"""
@@ -83,19 +83,15 @@ class StorageNameSpace:
namespace: str
global_config: dict[str, Any]
async def index_done_callback(self):
async def index_done_callback(self) -> None:
"""Commit the storage operations after indexing"""
pass
async def query_done_callback(self):
"""Commit the storage operations after querying"""
pass
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set = field(default_factory=set)
meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
raise NotImplementedError
@@ -106,12 +102,20 @@ class BaseVectorStorage(StorageNameSpace):
"""
raise NotImplementedError
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass
class BaseKVStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
embedding_func: EmbeddingFunc | None = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async def get_by_id(self, id: str) -> dict[str, Any] | None:
raise NotImplementedError
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -130,50 +134,75 @@ class BaseKVStorage(StorageNameSpace):
@dataclass
class BaseGraphStorage(StorageNameSpace):
embedding_func: EmbeddingFunc = None
embedding_func: EmbeddingFunc | None = None
"""Check if a node exists in the graph."""
async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
"""Check if an edge exists in the graph."""
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError
"""Get the degree of a node."""
async def node_degree(self, node_id: str) -> int:
raise NotImplementedError
"""Get the degree of an edge."""
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError
async def get_node(self, node_id: str) -> Union[dict, None]:
"""Get a node by its id."""
async def get_node(self, node_id: str) -> dict[str, str] | None:
raise NotImplementedError
"""Get an edge by its source and target node ids."""
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> dict[str, str] | None:
raise NotImplementedError
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
"""Get all edges connected to a node."""
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
raise NotImplementedError
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""Upsert a node into the graph."""
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
raise NotImplementedError
"""Upsert an edge into the graph."""
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
raise NotImplementedError
async def delete_node(self, node_id: str):
"""Delete a node from the graph."""
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""Embed nodes using an algorithm."""
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.")
"""Get all labels in the graph."""
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
"""Get a knowledge graph of a node."""
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
@@ -205,9 +234,9 @@ class DocProcessingStatus:
"""ISO format timestamp when document was created"""
updated_at: str
"""ISO format timestamp when document was last updated"""
chunks_count: Optional[int] = None
chunks_count: int | None = None
"""Number of chunks after splitting, used for processing"""
error: Optional[str] = None
error: str | None = None
"""Error message if failed"""
metadata: dict[str, Any] = field(default_factory=dict)
"""Additional metadata"""

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import httpx
from typing import Literal

View File

@@ -67,7 +67,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
config.get(
"auth_header_name", "X-Chroma-Token"
): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
@@ -154,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
results = self._collection.query(
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
query_embeddings=embedding.tolist()
if not isinstance(embedding, list)
else embedding,
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)

View File

@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
await self.delete([entity_id])
async def delete_entity_relation(self, entity_name: str):
async def delete_entity_relation(self, entity_name: str) -> None:
"""
Delete relations for a given entity by scanning metadata.
"""

View File

@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str):
async def delete_entity_relation(self, entity_name: str) -> None:
try:
relations = [
dp

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
import asyncio
import os
import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, Callable, Optional, Type, Union, cast
from typing import Any, AsyncIterator, Callable, Iterator, cast
from .base import (
BaseGraphStorage,
@@ -92,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS = {
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
@@ -179,7 +181,7 @@ STORAGES = {
}
def lazy_external_import(module_name: str, class_name: str):
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
@@ -188,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args, **kwargs):
def import_class(*args: Any, **kwargs: Any):
import importlib
module = importlib.import_module(module_name, package=package)
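For illustration, a hedged sketch of how this helper defers the import until the class is instantiated (the module path and constructor arguments below are hypothetical; the real paths live in the STORAGES mapping):
```python
# Nothing is imported until the returned callable is invoked
NanoVectorDBStorage = lazy_external_import(".kg.nano_vector_db_impl", "NanoVectorDBStorage")
storage = NanoVectorDBStorage(
    namespace="chunks", global_config=global_config, embedding_func=embedding_func
)
```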
@@ -305,7 +307,7 @@ class LightRAG:
- random_seed: Seed value for reproducibility.
"""
embedding_func: EmbeddingFunc = None
embedding_func: EmbeddingFunc | None = None
"""Function for computing text embeddings. Must be set before use."""
embedding_batch_num: int = 32
@@ -315,7 +317,7 @@ class LightRAG:
"""Maximum number of concurrent embedding function calls."""
# LLM Configuration
llm_model_func: callable = None
llm_model_func: Callable[..., object] | None = None
"""Function for interacting with the large language model (LLM). Must be set before use."""
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -345,10 +347,8 @@ class LightRAG:
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
# extension
addon_params: dict[str, Any] = field(default_factory=dict)
"""Dictionary for additional parameters and extensions."""
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
)
@@ -357,7 +357,7 @@ class LightRAG:
chunking_func: Callable[
[
str,
Optional[str],
str | None,
bool,
int,
int,
@@ -446,77 +446,74 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
) # type: ignore
self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
) # type: ignore
self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
)
self.key_string_value_json_storage_cls = partial(
) # type: ignore
self.key_string_value_json_storage_cls = partial( # type: ignore
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls = partial( # type: ignore
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls = partial( # type: ignore
self.graph_storage_cls, global_config=global_config
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.llm_response_cache = self.key_string_value_json_storage_cls(
self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
embedding_func=self.embedding_func,
)
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
),
embedding_func=self.embedding_func,
)
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
),
embedding_func=self.embedding_func,
)
self.entities_vdb = self.vector_db_storage_cls(
self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
),
@@ -530,13 +527,12 @@ class LightRAG:
embedding_func=None,
)
# What is this for? Is it necessary?
if self.llm_response_cache and hasattr(
self.llm_response_cache, "global_config"
):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
namespace=make_namespace(
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
@@ -545,7 +541,7 @@ class LightRAG:
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
@@ -562,68 +558,45 @@ class LightRAG:
node_label=nodel_label, max_depth=max_depth
)
def _get_storage_class(self, storage_name: str) -> dict:
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
def set_storage_client(self, db_client):
# Deprecated: set the correct value on LightRAG's *_storage attributes instead
# Inject the db client into the storage implementations (only tested on Oracle Database)
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,
self.doc_status,
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.key_string_value_json_storage_cls,
self.chunks_vdb,
self.relationships_vdb,
self.entities_vdb,
self.graph_storage_cls,
self.chunk_entity_relation_graph,
self.llm_response_cache,
]:
# set client
storage.db = db_client
def insert(
self,
string_or_strings: Union[str, list[str]],
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Sync Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
input: Single document string or list of document strings
split_by_character: if not None, split the text by this character first; any chunk longer than
chunk_size is further split by token size.
split_by_character_only: if True, split by character only; ignored when
split_by_character is None.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
self.ainsert(string_or_strings, split_by_character, split_by_character_only)
self.ainsert(input, split_by_character, split_by_character_only)
)
async def ainsert(
self,
string_or_strings: Union[str, list[str]],
input: str | list[str],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Async Insert documents with checkpoint support
Args:
string_or_strings: Single document string or list of document strings
input: Single document string or list of document strings
split_by_character: if not None, split the text by this character first; any chunk longer than
chunk_size is further split by token size.
split_by_character_only: if True, split by character only; ignored when
split_by_character is None.
"""
await self.apipeline_enqueue_documents(string_or_strings)
await self.apipeline_enqueue_documents(input)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
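A short hedged example of the splitting options described above (document strings are placeholders):
```python
rag.insert(
    ["First document ...", "Second document ..."],
    split_by_character="\n\n",      # split on blank lines first
    split_by_character_only=False,  # oversized pieces are re-split by token size
)
```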
@@ -680,7 +653,7 @@ class LightRAG:
if update_storage:
await self._insert_done()
async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
async def apipeline_enqueue_documents(self, input: str | list[str]):
"""
Pipeline for Processing Documents
@@ -689,11 +662,11 @@ class LightRAG:
3. Filter out already processed documents
4. Enqueue document in status
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
if isinstance(input, str):
input = [input]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
unique_contents = list(set(doc.strip() for doc in input))
# 2. Generate document IDs and initial status
new_docs: dict[str, Any] = {
@@ -860,32 +833,32 @@ class LightRAG:
raise e
async def _insert_done(self):
tasks = []
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
tasks = [
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]
if storage_inst is not None
]
await asyncio.gather(*tasks)
def insert_custom_kg(self, custom_kg: dict):
def insert_custom_kg(self, custom_kg: dict[str, Any]):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
async def ainsert_custom_kg(self, custom_kg: dict):
async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
update_storage = False
try:
# Insert chunks into vector storage
all_chunks_data = {}
chunk_to_source_map = {}
for chunk_data in custom_kg.get("chunks", []):
all_chunks_data: dict[str, dict[str, str]] = {}
chunk_to_source_map: dict[str, str] = {}
for chunk_data in custom_kg.get("chunks", {}):
chunk_content = chunk_data["content"]
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -895,13 +868,13 @@ class LightRAG:
chunk_to_source_map[source_id] = chunk_id
update_storage = True
if self.chunks_vdb is not None and all_chunks_data:
if all_chunks_data:
await self.chunks_vdb.upsert(all_chunks_data)
if self.text_chunks is not None and all_chunks_data:
if all_chunks_data:
await self.text_chunks.upsert(all_chunks_data)
# Insert entities into knowledge graph
all_entities_data = []
all_entities_data: list[dict[str, str]] = []
for entity_data in custom_kg.get("entities", []):
entity_name = f'"{entity_data["entity_name"].upper()}"'
entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -917,7 +890,7 @@ class LightRAG:
)
# Prepare node data
node_data = {
node_data: dict[str, str] = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
@@ -931,7 +904,7 @@ class LightRAG:
update_storage = True
# Insert relationships into knowledge graph
all_relationships_data = []
all_relationships_data: list[dict[str, str]] = []
for relationship_data in custom_kg.get("relationships", []):
src_id = f'"{relationship_data["src_id"].upper()}"'
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -973,7 +946,7 @@ class LightRAG:
"source_id": source_id,
},
)
edge_data = {
edge_data: dict[str, str] = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
@@ -983,41 +956,68 @@ class LightRAG:
update_storage = True
# Insert entities into vector storage if needed
if self.entities_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
await self.entities_vdb.upsert(data_for_vdb)
for dp in all_entities_data
}
await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
if self.relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
await self.relationships_vdb.upsert(data_for_vdb)
for dp in all_relationships_data
}
await self.relationships_vdb.upsert(data_for_vdb)
finally:
if update_storage:
await self._insert_done()
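Based on the keys read above, a hedged sketch of a custom_kg payload (values are illustrative; optional fields are omitted):
```python
custom_kg = {
    "chunks": [
        {"content": "Tesla was founded in 2003.", "source_id": "doc-1"},
    ],
    "entities": [
        {
            "entity_name": "Tesla",
            "entity_type": "ORGANIZATION",
            "description": "Electric vehicle manufacturer.",
            "source_id": "doc-1",
        },
    ],
    "relationships": [
        {
            "src_id": "Tesla",
            "tgt_id": "Elon Musk",
            "description": "Company founded by the entrepreneur.",
            "keywords": "founder",
            "source_id": "doc-1",
        },
    ],
}
rag.insert_custom_kg(custom_kg)
```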
def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
def query(
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
) -> str | Iterator[str]:
"""
Perform a sync query.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
str: The result of the query execution.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, prompt, param))
return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
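With the reordered signature, a minimal hedged call looks like this:
```python
answer = rag.query(
    "What are the main themes across the ingested documents?",
    param=QueryParam(mode="hybrid"),
    prompt=None,  # or a custom system prompt string
)
```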
async def aquery(
self, query: str, prompt: str = "", param: QueryParam = QueryParam()
):
self,
query: str,
param: QueryParam = QueryParam(),
prompt: str | None = None,
) -> str | AsyncIterator[str]:
"""
Perform an async query.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
str: The result of the query execution.
"""
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
@@ -1097,7 +1097,7 @@ class LightRAG:
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
):
) -> str | AsyncIterator[str]:
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1120,8 +1120,8 @@ class LightRAG:
),
)
param.hl_keywords = (hl_keywords,)
param.ll_keywords = (ll_keywords,)
param.hl_keywords = hl_keywords
param.ll_keywords = ll_keywords
# ---------------------
# STEP 2: Final Query Logic
@@ -1149,7 +1149,7 @@ class LightRAG:
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
),
global_config=asdict(self),
embedding_func=self.embedding_funcne,
embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
@@ -1198,12 +1198,7 @@ class LightRAG:
return response
async def _query_done(self):
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
await self.llm_response_cache.index_done_callback()
def delete_by_entity(self, entity_name: str):
loop = always_get_an_event_loop()
@@ -1225,16 +1220,16 @@ class LightRAG:
logger.error(f"Error while deleting entity '{entity_name}': {e}")
async def _delete_by_entity_done(self):
tasks = []
for storage_inst in [
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
await asyncio.gather(
*[
cast(StorageNameSpace, storage_inst).index_done_callback()
for storage_inst in [ # type: ignore
self.entities_vdb,
self.relationships_vdb,
self.chunk_entity_relation_graph,
]
]
)
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
"""Get summary of document content
@@ -1259,7 +1254,7 @@ class LightRAG:
"""
return await self.doc_status.get_status_counts()
async def adelete_by_doc_id(self, doc_id: str):
async def adelete_by_doc_id(self, doc_id: str) -> None:
"""Delete a document and all its related data
Args:
@@ -1276,6 +1271,9 @@ class LightRAG:
# 2. Get all related chunks
chunks = await self.text_chunks.get_by_id(doc_id)
if not chunks:
return
chunk_ids = list(chunks.keys())
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
@@ -1446,13 +1444,9 @@ class LightRAG:
except Exception as e:
logger.error(f"Error while deleting document {doc_id}: {e}")
def delete_by_doc_id(self, doc_id: str):
"""Synchronous version of adelete"""
return asyncio.run(self.adelete_by_doc_id(doc_id))
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity
Args:
@@ -1472,7 +1466,7 @@ class LightRAG:
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": node_data,
@@ -1486,21 +1480,6 @@ class LightRAG:
return result
def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
"""Synchronous version of getting entity information
Args:
entity_name: Entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
finally:
tracemalloc.stop()
async def get_relation_info(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
@@ -1528,7 +1507,7 @@ class LightRAG:
)
source_id = edge_data.get("source_id") if edge_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"src_entity": src_entity,
"tgt_entity": tgt_entity,
"source_id": source_id,
@@ -1542,23 +1521,3 @@ class LightRAG:
result["vector_data"] = vector_data[0] if vector_data else None
return result
def get_relation_info_sync(
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
):
"""Synchronous version of getting relationship information
Args:
src_entity: Source entity name (no need for quotes)
tgt_entity: Target entity name (no need for quotes)
include_vector_data: Whether to include data from the vector database
"""
try:
import tracemalloc
tracemalloc.start()
return asyncio.run(
self.get_relation_info(src_entity, tgt_entity, include_vector_data)
)
finally:
tracemalloc.stop()

View File

@@ -1,4 +1,6 @@
from typing import List, Dict, Callable, Any
from __future__ import annotations
from typing import Callable, Any
from pydantic import BaseModel, Field
@@ -23,7 +25,7 @@ class Model(BaseModel):
...,
description="A function that generates the response from the llm. The response must be a string",
)
kwargs: Dict[str, Any] = Field(
kwargs: dict[str, Any] = Field(
...,
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
)
@@ -57,7 +59,7 @@ class MultiModel:
```
"""
def __init__(self, models: List[Model]):
def __init__(self, models: list[Model]):
self._models = models
self._current_model = 0
@@ -66,7 +68,11 @@ class MultiModel:
return self._models[self._current_model]
async def llm_model_func(
self, prompt, system_prompt=None, history_messages=[], **kwargs
self,
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] = [],
**kwargs: Any,
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("keyword_extraction", None)

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Iterable

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
import json
import re
from tqdm.asyncio import tqdm as tqdm_async
from typing import Any, Union
from typing import Any, AsyncIterator
from collections import Counter, defaultdict
from .utils import (
logger,
@@ -36,7 +38,7 @@ import time
def chunking_by_token_size(
content: str,
split_by_character: Union[str, None] = None,
split_by_character: str | None = None,
split_by_character_only: bool = False,
overlap_token_size: int = 128,
max_token_size: int = 1024,
@@ -335,9 +337,9 @@ async def extract_entities(
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
llm_response_cache: BaseKVStorage = None,
) -> Union[BaseGraphStorage, None]:
global_config: dict[str, str],
llm_response_cache: BaseKVStorage | None = None,
) -> BaseGraphStorage | None:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[
@@ -603,15 +605,15 @@ async def extract_entities(
async def kg_query(
query,
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
prompt: str = "",
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
prompt: str | None = None,
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]
@@ -721,8 +723,8 @@ async def kg_query(
async def extract_keywords_only(
text: str,
param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
"""
Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -818,9 +820,9 @@ async def mix_kg_vector_query(
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1539,13 +1541,13 @@ def combine_contexts(entities, relationships, sources):
async def naive_query(
query,
query: str,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
):
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1646,9 +1648,9 @@ async def kg_query_with_keywords(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> str | AsyncIterator[str]:
"""
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
GRAPH_FIELD_SEP = "<SEP>"
PROMPTS = {}

View File

@@ -1,16 +1,18 @@
from __future__ import annotations
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from typing import Any, Optional
class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: List[str]
low_level_keywords: List[str]
high_level_keywords: list[str]
low_level_keywords: list[str]
class KnowledgeGraphNode(BaseModel):
id: str
labels: List[str]
properties: Dict[str, Any] # anything else goes here
labels: list[str]
properties: dict[str, Any] # anything else goes here
class KnowledgeGraphEdge(BaseModel):
@@ -18,9 +20,9 @@ class KnowledgeGraphEdge(BaseModel):
type: Optional[str]
source: str # id of source node
target: str # id of target node
properties: Dict[str, Any] # anything else goes here
properties: dict[str, Any] # anything else goes here
class KnowledgeGraph(BaseModel):
nodes: List[KnowledgeGraphNode] = []
edges: List[KnowledgeGraphEdge] = []
nodes: list[KnowledgeGraphNode] = []
edges: list[KnowledgeGraphEdge] = []

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import html
import io
@@ -9,7 +11,7 @@ import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List, Optional
from typing import Any, Callable
import xml.etree.ElementTree as ET
import bs4
@@ -67,12 +69,12 @@ class EmbeddingFunc:
@dataclass
class ReasoningResponse:
reasoning_content: str
reasoning_content: str | None
response_content: str
tag: str
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
def locate_json_string_body_from_string(content: str) -> str | None:
"""Locate the JSON string body from a string"""
try:
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
raise e from None
def compute_args_hash(*args, cache_type: str = None) -> str:
def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
return hashlib.md5(args_str.encode()).hexdigest()
def compute_mdhash_id(content, prefix: str = ""):
def compute_mdhash_id(content: str, prefix: str = "") -> str:
"""
Compute a unique ID for a given content string.
The ID is a combination of the given prefix and the MD5 hash of the content string.
"""
return prefix + md5(content.encode()).hexdigest()
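For example (a hedged illustration of the prefix-plus-MD5 scheme):
```python
chunk_id = compute_mdhash_id("Tesla was founded in 2003.", prefix="chunk-")
# -> "chunk-" followed by the 32-character MD5 hex digest of the content
```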
@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value):
def is_float_regex(value: str) -> bool:
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
def truncate_list_by_token_size(
list_data: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[int]:
"""Truncate a list of data by token size"""
if max_token_size <= 0:
return []
@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data
def list_of_list_to_csv(data: List[List[str]]) -> str:
def list_of_list_to_csv(data: list[list[str]]) -> str:
output = io.StringIO()
writer = csv.writer(
output,
@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
return output.getvalue()
def csv_string_to_list(csv_string: str) -> List[List[str]]:
def csv_string_to_list(csv_string: str) -> list[list[str]]:
# Clean the string by removing NUL characters
cleaned_string = csv_string.replace("\0", "")
@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
return None
def process_combine_contexts(hl, ll):
def process_combine_contexts(hl: str, ll: str):
header = None
list_hl = csv_string_to_list(hl.strip())
list_ll = csv_string_to_list(ll.strip())
@@ -375,7 +384,7 @@ async def get_best_cached_response(
llm_func=None,
original_prompt=None,
cache_type=None,
) -> Union[str, None]:
) -> str | None:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
return dot_product / (norm1 * norm2)
def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
"""Quantize embedding to specified bits"""
# Convert list to numpy array if needed
if isinstance(embedding, list):
@@ -570,9 +579,9 @@ class CacheData:
args_hash: str
content: str
prompt: str
quantized: Optional[np.ndarray] = None
min_val: Optional[float] = None
max_val: Optional[float] = None
quantized: np.ndarray | None = None
min_val: float | None = None
max_val: float | None = None
mode: str = "default"
cache_type: str = "query"
@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
return False
def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
def get_conversation_turns(
conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
"""
Process conversation history to get the specified number of complete turns.
@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
Formatted string of the conversation history
"""
# Group messages into turns
turns = []
messages = []
turns: list[list[dict[str, Any]]] = []
messages: list[dict[str, Any]] = []
# First, filter out keyword extraction messages
for msg in conversation_history:
@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
turns = turns[-num_turns:]
# Format the turns into a string
formatted_turns = []
formatted_turns: list[str] = []
for turn in turns:
formatted_turns.extend(
[f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]