Some enhancements:

- Enable the llm_cache storage to support get_by_mode_and_id, improving performance when using a real KV server
- Provide an option for developers to cache the LLM responses produced while extracting entities from a document. This addresses the pain point that when the process fails partway through, the already-processed chunks must be sent to the LLM again, wasting money and time. With the new option enabled (it is disabled by default), those intermediate results are cached, which can save beginners significant time and money.
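
In practice this is a one-line switch at initialization. A minimal sketch using only the two parameters touched by this commit (the rest of the setup is elided):

```python
from lightrag import LightRAG

# Sketch: enable the new entity-extraction cache so a failed indexing run
# can be retried without re-paying for chunks the LLM already processed.
rag = LightRAG(
    working_dir="./lightrag_cache",
    enable_llm_cache=True,                     # response cache (on by default)
    enable_llm_cache_for_entity_extract=True,  # new option, off by default
)
```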
Samuel Chan
2025-01-06 12:50:05 +08:00
parent 6c1b669f0f
commit 6ae27d8f06
7 changed files with 182 additions and 70 deletions


@@ -26,6 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
</div>
## 🎉 News
- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage).
- [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
@@ -356,6 +357,11 @@ rag = LightRAG(
```
see test_neo4j.py for a working example.
### Using PostgreSQL for Storage
For production-level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you, acting as the KV store, VectorDB (pgvector) and GraphDB (Apache AGE) at once.
* PostgreSQL is lightweight; the whole binary distribution, including all necessary plugins, can be zipped to 40MB. See the [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0); installation on Linux/Mac is straightforward.
* How to start? See [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py), and the rough sketch below.
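
A rough wiring sketch (the class names `PGKVStorage` and `PGGraphStorage` appear in this commit's diff; `PGVectorStorage` and the omitted connection setup are assumptions, so treat the demo script above as the authoritative reference):

```python
from lightrag import LightRAG

# Hypothetical sketch only: storage backends are selected by class name,
# the same way the Oracle/Neo4J backends are in the parameter table below.
rag = LightRAG(
    working_dir="./lightrag_pg",
    kv_storage="PGKVStorage",          # documents, chunks, LLM response cache
    vector_storage="PGVectorStorage",  # pgvector-backed embeddings (assumed name)
    graph_storage="PGGraphStorage",    # Apache AGE-backed graph
)
```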
### Insert Custom KG
```python
@@ -602,33 +608,34 @@ if __name__ == "__main__":
### LightRAG init parameters
| **Parameter** | **Type** | **Explanation** | **Default** |
| --- | --- | --- | --- |
| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` |
| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` |
| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **enable\_llm\_cache\_for\_entity\_extract** | `bool` | If `TRUE`, stores LLM results in cache for the entity extraction stage; good for beginners who want to debug their application without repeated LLM cost | `FALSE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
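
For instance, a stricter question-answer cache than the default could look like this (a sketch built only from the keys documented in the row above):

```python
from lightrag import LightRAG

# Sketch: return a cached answer when a new question's embedding similarity
# to a cached one exceeds the threshold, skipping the LLM call entirely.
rag = LightRAG(
    working_dir="./lightrag_cache",
    embedding_cache_config={
        "enabled": True,
        "similarity_threshold": 0.98,  # stricter than the 0.95 default
        "use_llm_check": False,        # set True for an LLM secondary check
    },
)
```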
### Error Handling
<details>
@@ -1206,6 +1213,7 @@ curl "http://localhost:9621/health"
```
## Development
Contribute to the project: [Guide](contributor-readme.MD)
### Running in Development Mode

contributor-readme.MD (new file)

@@ -0,0 +1,12 @@
# Handy Tips for Developers Who Want to Contribute to the Project
## Pre-commit Hooks
Please ensure you have run pre-commit hooks before committing your changes.
### Guides
1. **Installing Pre-commit Hooks**:
- Install pre-commit using pip: `pip install pre-commit`
- Initialize pre-commit in your repository: `pre-commit install`
- Run pre-commit hooks: `pre-commit run --all-files`
2. **Pre-commit Hooks Configuration**:
- Create a `.pre-commit-config.yaml` file in the root of your repository.
- Add your hooks to the `.pre-commit-config.yaml` file; a minimal example follows below.
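
A minimal `.pre-commit-config.yaml` to start from (the hook choices here are illustrative, not the repository's actual configuration):

```yaml
# Illustrative starter config; swap in the hooks your project actually uses.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
```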


@@ -43,6 +43,7 @@ async def main():
         llm_model_name="glm-4-flashx",
         llm_model_max_async=4,
         llm_model_max_token_size=32768,
+        enable_llm_cache_for_entity_extract=True,
         embedding_func=EmbeddingFunc(
             embedding_dim=768,
             max_token_size=8192,


@@ -151,7 +151,10 @@ class PostgreSQLDB:
         try:
             await conn.execute('SET search_path = ag_catalog, "$user", public')
             await conn.execute(f"""select create_graph('{graph_name}')""")
-        except asyncpg.exceptions.InvalidSchemaNameError:
+        except (
+            asyncpg.exceptions.InvalidSchemaNameError,
+            asyncpg.exceptions.UniqueViolationError,
+        ):
             pass
@@ -160,7 +163,6 @@ class PGKVStorage(BaseKVStorage):
     db: PostgreSQLDB = None

     def __post_init__(self):
-        self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]

     ################ QUERY METHODS ################
@@ -181,6 +183,19 @@ class PGKVStorage(BaseKVStorage):
         else:
             return None

+    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
+        """Specifically for llm_response_cache."""
+        sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
+        params = {"workspace": self.db.workspace, "mode": mode, "id": id}
+        if "llm_response_cache" == self.namespace:
+            array_res = await self.db.query(sql, params, multirows=True)
+            res = {}
+            for row in array_res:
+                res[row["id"]] = row
+            return res
+        else:
+            return None
+
     # Query by id
     async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
         """Get doc_chunks data by id"""
@@ -229,33 +244,30 @@ class PGKVStorage(BaseKVStorage):
     ################ INSERT METHODS ################
     async def upsert(self, data: Dict[str, dict]):
-        left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
         if self.namespace == "text_chunks":
             pass
         elif self.namespace == "full_docs":
-            for k, v in self._data.items():
+            for k, v in data.items():
                 upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
-                data = {
+                _data = {
                     "id": k,
                     "content": v["content"],
                     "workspace": self.db.workspace,
                 }
-                await self.db.execute(upsert_sql, data)
+                await self.db.execute(upsert_sql, _data)
         elif self.namespace == "llm_response_cache":
-            for mode, items in self._data.items():
+            for mode, items in data.items():
                 for k, v in items.items():
                     upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
-                    data = {
+                    _data = {
                         "workspace": self.db.workspace,
                         "id": k,
                         "original_prompt": v["original_prompt"],
-                        "return": v["return"],
+                        "return_value": v["return"],
                         "mode": mode,
                     }
-                    await self.db.execute(upsert_sql, data)
-        return left_data
+                    await self.db.execute(upsert_sql, _data)

     async def index_done_callback(self):
         if self.namespace in ["full_docs", "text_chunks"]:
@@ -977,9 +989,6 @@ class PGGraphStorage(BaseGraphStorage):
         source_node_label = source_node_id.strip('"')
         target_node_label = target_node_id.strip('"')
         edge_properties = edge_data
-        logger.info(
-            f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
-        )

         query = """MATCH (source:`{src_label}`)
                    WITH source
@@ -1028,8 +1037,8 @@ TABLES = {
                    doc_name VARCHAR(1024),
                    content TEXT,
                    meta JSONB,
-                   createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                   updatetime TIMESTAMP,
+                   create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                   update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
                    )"""
     },
@@ -1042,8 +1051,8 @@ TABLES = {
                    tokens INTEGER,
                    content TEXT,
                    content_vector VECTOR,
-                   createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                   updatetime TIMESTAMP,
+                   create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                   update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
     },
@@ -1054,8 +1063,8 @@ TABLES = {
                    entity_name VARCHAR(255),
                    content TEXT,
                    content_vector VECTOR,
-                   createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                   updatetime TIMESTAMP,
+                   create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                   update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
                    )"""
     },
@@ -1067,8 +1076,8 @@ TABLES = {
                    target_id VARCHAR(256),
                    content TEXT,
                    content_vector VECTOR,
-                   createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                   updatetime TIMESTAMP,
+                   create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                   update_time TIMESTAMP,
                    CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
                    )"""
     },
@@ -1078,10 +1087,10 @@ TABLES = {
                    id varchar(255) NOT NULL,
                    mode varchar(32) NOT NULL,
                    original_prompt TEXT,
-                   return TEXT,
-                   createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                   updatetime TIMESTAMP,
-                   CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
+                   return_value TEXT,
+                   create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                   update_time TIMESTAMP,
+                   CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
                    )"""
     },
     "LIGHTRAG_DOC_STATUS": {
@@ -1109,9 +1118,12 @@ SQL_TEMPLATES = {
         chunk_order_index, full_doc_id
         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
     """,
-    "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
+    "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
     """,
+    "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
+        FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
+    """,
     "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
         FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
     """,
@@ -1119,22 +1131,22 @@ SQL_TEMPLATES = {
         chunk_order_index, full_doc_id
         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
     """,
-    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
+    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode IN ({ids})
     """,
     "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
     "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
         VALUES ($1, $2, $3)
         ON CONFLICT (workspace,id) DO UPDATE
-        SET content = $2, updatetime = CURRENT_TIMESTAMP
+        SET content = $2, update_time = CURRENT_TIMESTAMP
     """,
-    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,"return",mode)
+    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode)
         VALUES ($1, $2, $3, $4, $5)
-        ON CONFLICT (workspace,id) DO UPDATE
+        ON CONFLICT (workspace,mode,id) DO UPDATE
         SET original_prompt = EXCLUDED.original_prompt,
-        "return"=EXCLUDED."return",
+        return_value=EXCLUDED.return_value,
         mode=EXCLUDED.mode,
-        updatetime = CURRENT_TIMESTAMP
+        update_time = CURRENT_TIMESTAMP
     """,
     "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
         chunk_order_index, full_doc_id, content, content_vector)
@@ -1145,7 +1157,7 @@ SQL_TEMPLATES = {
         full_doc_id=EXCLUDED.full_doc_id,
         content = EXCLUDED.content,
         content_vector=EXCLUDED.content_vector,
-        updatetime = CURRENT_TIMESTAMP
+        update_time = CURRENT_TIMESTAMP
     """,
     "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
         VALUES ($1, $2, $3, $4, $5)
@@ -1153,7 +1165,7 @@ SQL_TEMPLATES = {
         SET entity_name=EXCLUDED.entity_name,
         content=EXCLUDED.content,
         content_vector=EXCLUDED.content_vector,
-        updatetime=CURRENT_TIMESTAMP
+        update_time=CURRENT_TIMESTAMP
     """,
     "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
         target_id, content, content_vector)
@@ -1162,7 +1174,7 @@ SQL_TEMPLATES = {
         SET source_id=EXCLUDED.source_id,
         target_id=EXCLUDED.target_id,
         content=EXCLUDED.content,
-        content_vector=EXCLUDED.content_vector, updatetime = CURRENT_TIMESTAMP
+        content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
     """,
     # SQL for VectorStorage
     "entities": """SELECT entity_name FROM


@@ -176,6 +176,8 @@ class LightRAG:
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

     enable_llm_cache: bool = True
+    # Sometimes the LLM fails during entity extraction; this flag caches those responses so a re-run can resume without paying for the successful calls again
+    enable_llm_cache_for_entity_extract: bool = False

     # extension
     addon_params: dict = field(default_factory=dict)
@@ -402,6 +404,7 @@ class LightRAG:
             knowledge_graph_inst=self.chunk_entity_relation_graph,
             entity_vdb=self.entities_vdb,
             relationships_vdb=self.relationships_vdb,
+            llm_response_cache=self.llm_response_cache,
             global_config=asdict(self),
         )


@@ -253,9 +253,13 @@ async def extract_entities(
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict,
+    llm_response_cache: BaseKVStorage = None,
 ) -> Union[BaseGraphStorage, None]:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
+    enable_llm_cache_for_entity_extract: bool = global_config[
+        "enable_llm_cache_for_entity_extract"
+    ]

     ordered_chunks = list(chunks.items())
     # add language and example number params to prompt
@@ -300,6 +304,52 @@ async def extract_entities(
     already_entities = 0
     already_relations = 0

+    async def _user_llm_func_with_cache(
+        input_text: str, history_messages: list[dict[str, str]] = None
+    ) -> str:
+        if enable_llm_cache_for_entity_extract and llm_response_cache:
+            need_to_restore = False
+            if (
+                global_config["embedding_cache_config"]
+                and global_config["embedding_cache_config"]["enabled"]
+            ):
+                new_config = global_config.copy()
+                new_config["embedding_cache_config"] = None
+                new_config["enable_llm_cache"] = True
+                llm_response_cache.global_config = new_config
+                need_to_restore = True
+            if history_messages:
+                history = json.dumps(history_messages)
+                _prompt = history + "\n" + input_text
+            else:
+                _prompt = input_text
+
+            arg_hash = compute_args_hash(_prompt)
+            cached_return, _1, _2, _3 = await handle_cache(
+                llm_response_cache, arg_hash, _prompt, "default"
+            )
+            if need_to_restore:
+                llm_response_cache.global_config = global_config
+            if cached_return:
+                return cached_return
+
+            if history_messages:
+                res: str = await use_llm_func(
+                    input_text, history_messages=history_messages
+                )
+            else:
+                res: str = await use_llm_func(input_text)
+            await save_to_cache(
+                llm_response_cache,
+                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+            )
+            return res
+
+        if history_messages:
+            return await use_llm_func(input_text, history_messages=history_messages)
+        else:
+            return await use_llm_func(input_text)
+
     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         nonlocal already_processed, already_entities, already_relations
         chunk_key = chunk_key_dp[0]
@@ -310,17 +360,19 @@ async def extract_entities(
             **context_base, input_text="{input_text}"
         ).format(**context_base, input_text=content)

-        final_result = await use_llm_func(hint_prompt)
+        final_result = await _user_llm_func_with_cache(hint_prompt)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

         for now_glean_index in range(entity_extract_max_gleaning):
-            glean_result = await use_llm_func(continue_prompt, history_messages=history)
+            glean_result = await _user_llm_func_with_cache(
+                continue_prompt, history_messages=history
+            )

             history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
             final_result += glean_result
             if now_glean_index == entity_extract_max_gleaning - 1:
                 break

-            if_loop_result: str = await use_llm_func(
+            if_loop_result: str = await _user_llm_func_with_cache(
                 if_loop_prompt, history_messages=history
             )
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()


@@ -454,7 +454,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
     # For naive mode, only use simple cache matching
     if mode == "naive":
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if exists_func(hashing_kv, "get_by_mode_and_id"):
+            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
+        else:
+            mode_cache = await hashing_kv.get_by_id(mode) or {}
         if args_hash in mode_cache:
             return mode_cache[args_hash]["return"], None, None, None
         return None, None, None, None
@@ -488,7 +491,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
         return best_cached_response, None, None, None
     else:
         # Use regular cache
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if exists_func(hashing_kv, "get_by_mode_and_id"):
+            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
+        else:
+            mode_cache = await hashing_kv.get_by_id(mode) or {}
         if args_hash in mode_cache:
             return mode_cache[args_hash]["return"], None, None, None
@@ -510,7 +516,13 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
         return

-    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+    if exists_func(hashing_kv, "get_by_mode_and_id"):
+        mode_cache = (
+            await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
+            or {}
+        )
+    else:
+        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
@@ -543,3 +555,15 @@ def safe_unicode_decode(content):
         )
     return decoded_content

+
+def exists_func(obj, func_name: str) -> bool:
+    """Check if a function exists in an object or not.
+    :param obj:
+    :param func_name:
+    :return: True / False
+    """
+    if callable(getattr(obj, func_name, None)):
+        return True
+    else:
+        return False
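
This duck-typing check is what lets the JSON file storage keep working unchanged while the Postgres backend takes the new fast path. A quick illustration with hypothetical minimal classes (not part of the codebase):

```python
# Hypothetical stand-ins for the real storage classes.
class JsonKV:
    async def get_by_id(self, id):  # no get_by_mode_and_id -> fallback path
        return {}

class PgKV:
    async def get_by_mode_and_id(self, mode, id):  # fast path
        return {}

assert not exists_func(JsonKV(), "get_by_mode_and_id")
assert exists_func(PgKV(), "get_by_mode_and_id")
```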