Merge pull request #545 from ShanGor/main

Enhance the llm_cache_kv_store, enable the llm_cache for entity extraction and revise readme
This commit is contained in:
zrguo
2025-01-06 15:24:34 +08:00
committed by GitHub
7 changed files with 182 additions and 70 deletions

View File

@@ -26,6 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
</div>
## 🎉 News
- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage).
- [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
@@ -356,6 +357,11 @@ rag = LightRAG(
```
see test_neo4j.py for a working example.
### Using PostgreSQL for Storage
For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
* PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
* How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
### Insert Custom KG
```python
@@ -602,33 +608,34 @@ if __name__ == "__main__":
### LightRAG init parameters
| **Parameter** | **Type** | **Explanation** | **Default** |
| --- | --- | --- | --- |
| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` |
| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` |
| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
| **Parameter** | **Type** | **Explanation** | **Default** |
|----------------------------------------------| --- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------|
| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` |
| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` |
| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **enable\_llm\_cache\_for\_entity\_extract** | `bool` | If `TRUE`, stores LLM results in cache for entity extraction; Good for beginners to debug your application | `FALSE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
### Error Handling
<details>
@@ -1206,6 +1213,7 @@ curl "http://localhost:9621/health"
```
## Development
Contribute to the project: [Guide](contributor-readme.MD)
### Running in Development Mode

12
contributor-readme.MD Normal file
View File

@@ -0,0 +1,12 @@
# Handy Tips for Developers Who Want to Contribute to the Project
## Pre-commit Hooks
Please ensure you have run pre-commit hooks before committing your changes.
### Guides
1. **Installing Pre-commit Hooks**:
- Install pre-commit using pip: `pip install pre-commit`
- Initialize pre-commit in your repository: `pre-commit install`
- Run pre-commit hooks: `pre-commit run --all-files`
2. **Pre-commit Hooks Configuration**:
- Create a `.pre-commit-config.yaml` file in the root of your repository.
- Add your hooks to the `.pre-commit-config.yaml`file.

View File

@@ -43,6 +43,7 @@ async def main():
llm_model_name="glm-4-flashx",
llm_model_max_async=4,
llm_model_max_token_size=32768,
enable_llm_cache_for_entity_extract=True,
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,

View File

@@ -151,7 +151,10 @@ class PostgreSQLDB:
try:
await conn.execute('SET search_path = ag_catalog, "$user", public')
await conn.execute(f"""select create_graph('{graph_name}')""")
except asyncpg.exceptions.InvalidSchemaNameError:
except (
asyncpg.exceptions.InvalidSchemaNameError,
asyncpg.exceptions.UniqueViolationError,
):
pass
@@ -160,7 +163,6 @@ class PGKVStorage(BaseKVStorage):
db: PostgreSQLDB = None
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
################ QUERY METHODS ################
@@ -181,6 +183,19 @@ class PGKVStorage(BaseKVStorage):
else:
return None
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
"""Specifically for llm_response_cache."""
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, mode: mode, "id": id}
if "llm_response_cache" == self.namespace:
array_res = await self.db.query(sql, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
# Query by id
async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
"""Get doc_chunks data by id"""
@@ -229,33 +244,30 @@ class PGKVStorage(BaseKVStorage):
################ INSERT METHODS ################
async def upsert(self, data: Dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
if self.namespace == "text_chunks":
pass
elif self.namespace == "full_docs":
for k, v in self._data.items():
for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
data = {
_data = {
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
await self.db.execute(upsert_sql, data)
await self.db.execute(upsert_sql, _data)
elif self.namespace == "llm_response_cache":
for mode, items in self._data.items():
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
data = {
_data = {
"workspace": self.db.workspace,
"id": k,
"original_prompt": v["original_prompt"],
"return": v["return"],
"return_value": v["return"],
"mode": mode,
}
await self.db.execute(upsert_sql, data)
return left_data
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]:
@@ -977,9 +989,6 @@ class PGGraphStorage(BaseGraphStorage):
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
edge_properties = edge_data
logger.info(
f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
)
query = """MATCH (source:`{src_label}`)
WITH source
@@ -1028,8 +1037,8 @@ TABLES = {
doc_name VARCHAR(1024),
content TEXT,
meta JSONB,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1042,8 +1051,8 @@ TABLES = {
tokens INTEGER,
content TEXT,
content_vector VECTOR,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1054,8 +1063,8 @@ TABLES = {
entity_name VARCHAR(255),
content TEXT,
content_vector VECTOR,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1067,8 +1076,8 @@ TABLES = {
target_id VARCHAR(256),
content TEXT,
content_vector VECTOR,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1078,10 +1087,10 @@ TABLES = {
id varchar(255) NOT NULL,
mode varchar(32) NOT NULL,
original_prompt TEXT,
return TEXT,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP,
CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
return_value TEXT,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
)"""
},
"LIGHTRAG_DOC_STATUS": {
@@ -1109,9 +1118,12 @@ SQL_TEMPLATES = {
chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
""",
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
""",
@@ -1119,22 +1131,22 @@ SQL_TEMPLATES = {
chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids})
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
VALUES ($1, $2, $3)
ON CONFLICT (workspace,id) DO UPDATE
SET content = $2, updatetime = CURRENT_TIMESTAMP
SET content = $2, update_time = CURRENT_TIMESTAMP
""",
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,"return",mode)
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (workspace,id) DO UPDATE
ON CONFLICT (workspace,mode,id) DO UPDATE
SET original_prompt = EXCLUDED.original_prompt,
"return"=EXCLUDED."return",
return_value=EXCLUDED.return_value,
mode=EXCLUDED.mode,
updatetime = CURRENT_TIMESTAMP
update_time = CURRENT_TIMESTAMP
""",
"upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector)
@@ -1145,7 +1157,7 @@ SQL_TEMPLATES = {
full_doc_id=EXCLUDED.full_doc_id,
content = EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
updatetime = CURRENT_TIMESTAMP
update_time = CURRENT_TIMESTAMP
""",
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
VALUES ($1, $2, $3, $4, $5)
@@ -1153,7 +1165,7 @@ SQL_TEMPLATES = {
SET entity_name=EXCLUDED.entity_name,
content=EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
updatetime=CURRENT_TIMESTAMP
update_time=CURRENT_TIMESTAMP
""",
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
target_id, content, content_vector)
@@ -1162,7 +1174,7 @@ SQL_TEMPLATES = {
SET source_id=EXCLUDED.source_id,
target_id=EXCLUDED.target_id,
content=EXCLUDED.content,
content_vector=EXCLUDED.content_vector, updatetime = CURRENT_TIMESTAMP
content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
""",
# SQL for VectorStorage
"entities": """SELECT entity_name FROM

View File

@@ -176,6 +176,8 @@ class LightRAG:
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
enable_llm_cache: bool = True
# Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag
enable_llm_cache_for_entity_extract: bool = False
# extension
addon_params: dict = field(default_factory=dict)
@@ -402,6 +404,7 @@ class LightRAG:
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)

View File

@@ -253,9 +253,13 @@ async def extract_entities(
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
llm_response_cache: BaseKVStorage = None,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[
"enable_llm_cache_for_entity_extract"
]
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
@@ -300,6 +304,52 @@ async def extract_entities(
already_entities = 0
already_relations = 0
async def _user_llm_func_with_cache(
input_text: str, history_messages: list[dict[str, str]] = None
) -> str:
if enable_llm_cache_for_entity_extract and llm_response_cache:
need_to_restore = False
if (
global_config["embedding_cache_config"]
and global_config["embedding_cache_config"]["enabled"]
):
new_config = global_config.copy()
new_config["embedding_cache_config"] = None
new_config["enable_llm_cache"] = True
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages)
_prompt = history + "\n" + input_text
else:
_prompt = input_text
arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache, arg_hash, _prompt, "default"
)
if need_to_restore:
llm_response_cache.global_config = global_config
if cached_return:
return cached_return
if history_messages:
res: str = await use_llm_func(
input_text, history_messages=history_messages
)
else:
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
)
return res
if history_messages:
return await use_llm_func(input_text, history_messages=history_messages)
else:
return await use_llm_func(input_text)
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
nonlocal already_processed, already_entities, already_relations
chunk_key = chunk_key_dp[0]
@@ -310,17 +360,19 @@ async def extract_entities(
**context_base, input_text="{input_text}"
).format(**context_base, input_text=content)
final_result = await use_llm_func(hint_prompt)
final_result = await _user_llm_func_with_cache(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await use_llm_func(continue_prompt, history_messages=history)
glean_result = await _user_llm_func_with_cache(
continue_prompt, history_messages=history
)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
final_result += glean_result
if now_glean_index == entity_extract_max_gleaning - 1:
break
if_loop_result: str = await use_llm_func(
if_loop_result: str = await _user_llm_func_with_cache(
if_loop_prompt, history_messages=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()

View File

@@ -454,7 +454,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
# For naive mode, only use simple cache matching
if mode == "naive":
mode_cache = await hashing_kv.get_by_id(mode) or {}
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, None, None, None
@@ -488,7 +491,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
return best_cached_response, None, None, None
else:
# Use regular cache
mode_cache = await hashing_kv.get_by_id(mode) or {}
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
@@ -510,7 +516,13 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
return
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = (
await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
or {}
)
else:
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
@@ -543,3 +555,15 @@ def safe_unicode_decode(content):
)
return decoded_content
def exists_func(obj, func_name: str) -> bool:
"""Check if a function exists in an object or not.
:param obj:
:param func_name:
:return: True / False
"""
if callable(getattr(obj, func_name, None)):
return True
else:
return False