diff --git a/README.md b/README.md
index 54a84323..ed2a7789 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
## 🎉 News
+- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage).
- [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
@@ -356,6 +357,11 @@ rag = LightRAG(
```
see test_neo4j.py for a working example.
+### Using PostgreSQL for Storage
+For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
+* PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
+* How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
+
### Insert Custom KG
```python
@@ -602,33 +608,34 @@ if __name__ == "__main__":
### LightRAG init parameters
-| **Parameter** | **Type** | **Explanation** | **Default** |
-| --- | --- | --- | --- |
-| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
-| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
-| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
-| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
-| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
-| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
-| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
-| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
-| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
-| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
-| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
-| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
-| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
-| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
-| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
-| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
-| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
-| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` |
-| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` |
-| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
-| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
-| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
-| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
-| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
-| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
+| **Parameter** | **Type** | **Explanation** | **Default** |
+|----------------------------------------------| --- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------|
+| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
+| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
+| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
+| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
+| **log\_level** | | Log level for application runtime | `logging.DEBUG` |
+| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
+| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
+| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
+| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
+| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
+| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
+| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
+| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
+| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
+| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
+| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
+| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
+| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` |
+| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` |
+| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | |
+| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | |
+| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
+| **enable\_llm\_cache\_for\_entity\_extract** | `bool` | If `TRUE`, stores LLM results in cache for entity extraction; Good for beginners to debug your application | `FALSE` |
+| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` |
+| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
+| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
### Error Handling
@@ -1206,6 +1213,7 @@ curl "http://localhost:9621/health"
```
## Development
+Contribute to the project: [Guide](contributor-readme.MD)
### Running in Development Mode
diff --git a/contributor-readme.MD b/contributor-readme.MD
new file mode 100644
index 00000000..2168d469
--- /dev/null
+++ b/contributor-readme.MD
@@ -0,0 +1,12 @@
+# Handy Tips for Developers Who Want to Contribute to the Project
+## Pre-commit Hooks
+Please ensure you have run pre-commit hooks before committing your changes.
+### Guides
+1. **Installing Pre-commit Hooks**:
+ - Install pre-commit using pip: `pip install pre-commit`
+ - Initialize pre-commit in your repository: `pre-commit install`
+ - Run pre-commit hooks: `pre-commit run --all-files`
+
+2. **Pre-commit Hooks Configuration**:
+ - Create a `.pre-commit-config.yaml` file in the root of your repository.
+ - Add your hooks to the `.pre-commit-config.yaml`file.
diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py
index f28480df..d0461d84 100644
--- a/examples/lightrag_zhipu_postgres_demo.py
+++ b/examples/lightrag_zhipu_postgres_demo.py
@@ -43,6 +43,7 @@ async def main():
llm_model_name="glm-4-flashx",
llm_model_max_async=4,
llm_model_max_token_size=32768,
+ enable_llm_cache_for_entity_extract=True,
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 704fa476..033d63d6 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -151,7 +151,10 @@ class PostgreSQLDB:
try:
await conn.execute('SET search_path = ag_catalog, "$user", public')
await conn.execute(f"""select create_graph('{graph_name}')""")
- except asyncpg.exceptions.InvalidSchemaNameError:
+ except (
+ asyncpg.exceptions.InvalidSchemaNameError,
+ asyncpg.exceptions.UniqueViolationError,
+ ):
pass
@@ -160,7 +163,6 @@ class PGKVStorage(BaseKVStorage):
db: PostgreSQLDB = None
def __post_init__(self):
- self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
################ QUERY METHODS ################
@@ -181,6 +183,19 @@ class PGKVStorage(BaseKVStorage):
else:
return None
+ async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
+ """Specifically for llm_response_cache."""
+ sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
+ params = {"workspace": self.db.workspace, mode: mode, "id": id}
+ if "llm_response_cache" == self.namespace:
+ array_res = await self.db.query(sql, params, multirows=True)
+ res = {}
+ for row in array_res:
+ res[row["id"]] = row
+ return res
+ else:
+ return None
+
# Query by id
async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
"""Get doc_chunks data by id"""
@@ -229,33 +244,30 @@ class PGKVStorage(BaseKVStorage):
################ INSERT METHODS ################
async def upsert(self, data: Dict[str, dict]):
- left_data = {k: v for k, v in data.items() if k not in self._data}
- self._data.update(left_data)
if self.namespace == "text_chunks":
pass
elif self.namespace == "full_docs":
- for k, v in self._data.items():
+ for k, v in data.items():
upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
- data = {
+ _data = {
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
- await self.db.execute(upsert_sql, data)
+ await self.db.execute(upsert_sql, _data)
elif self.namespace == "llm_response_cache":
- for mode, items in self._data.items():
+ for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
- data = {
+ _data = {
"workspace": self.db.workspace,
"id": k,
"original_prompt": v["original_prompt"],
- "return": v["return"],
+ "return_value": v["return"],
"mode": mode,
}
- await self.db.execute(upsert_sql, data)
- return left_data
+ await self.db.execute(upsert_sql, _data)
async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]:
@@ -977,9 +989,6 @@ class PGGraphStorage(BaseGraphStorage):
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
edge_properties = edge_data
- logger.info(
- f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
- )
query = """MATCH (source:`{src_label}`)
WITH source
@@ -1028,8 +1037,8 @@ TABLES = {
doc_name VARCHAR(1024),
content TEXT,
meta JSONB,
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updatetime TIMESTAMP,
+ create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1042,8 +1051,8 @@ TABLES = {
tokens INTEGER,
content TEXT,
content_vector VECTOR,
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updatetime TIMESTAMP,
+ create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1054,8 +1063,8 @@ TABLES = {
entity_name VARCHAR(255),
content TEXT,
content_vector VECTOR,
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updatetime TIMESTAMP,
+ create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1067,8 +1076,8 @@ TABLES = {
target_id VARCHAR(256),
content TEXT,
content_vector VECTOR,
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updatetime TIMESTAMP,
+ create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ update_time TIMESTAMP,
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
)"""
},
@@ -1078,10 +1087,10 @@ TABLES = {
id varchar(255) NOT NULL,
mode varchar(32) NOT NULL,
original_prompt TEXT,
- return TEXT,
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updatetime TIMESTAMP,
- CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
+ return_value TEXT,
+ create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ update_time TIMESTAMP,
+ CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
)"""
},
"LIGHTRAG_DOC_STATUS": {
@@ -1109,9 +1118,12 @@ SQL_TEMPLATES = {
chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""",
- "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
+ "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
""",
+ "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
+ FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
+ """,
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
""",
@@ -1119,22 +1131,22 @@ SQL_TEMPLATES = {
chunk_order_index, full_doc_id
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""",
- "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
+ "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids})
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
VALUES ($1, $2, $3)
ON CONFLICT (workspace,id) DO UPDATE
- SET content = $2, updatetime = CURRENT_TIMESTAMP
+ SET content = $2, update_time = CURRENT_TIMESTAMP
""",
- "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,"return",mode)
+ "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode)
VALUES ($1, $2, $3, $4, $5)
- ON CONFLICT (workspace,id) DO UPDATE
+ ON CONFLICT (workspace,mode,id) DO UPDATE
SET original_prompt = EXCLUDED.original_prompt,
- "return"=EXCLUDED."return",
+ return_value=EXCLUDED.return_value,
mode=EXCLUDED.mode,
- updatetime = CURRENT_TIMESTAMP
+ update_time = CURRENT_TIMESTAMP
""",
"upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector)
@@ -1145,7 +1157,7 @@ SQL_TEMPLATES = {
full_doc_id=EXCLUDED.full_doc_id,
content = EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
- updatetime = CURRENT_TIMESTAMP
+ update_time = CURRENT_TIMESTAMP
""",
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
VALUES ($1, $2, $3, $4, $5)
@@ -1153,7 +1165,7 @@ SQL_TEMPLATES = {
SET entity_name=EXCLUDED.entity_name,
content=EXCLUDED.content,
content_vector=EXCLUDED.content_vector,
- updatetime=CURRENT_TIMESTAMP
+ update_time=CURRENT_TIMESTAMP
""",
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
target_id, content, content_vector)
@@ -1162,7 +1174,7 @@ SQL_TEMPLATES = {
SET source_id=EXCLUDED.source_id,
target_id=EXCLUDED.target_id,
content=EXCLUDED.content,
- content_vector=EXCLUDED.content_vector, updatetime = CURRENT_TIMESTAMP
+ content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
""",
# SQL for VectorStorage
"entities": """SELECT entity_name FROM
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index fc71508c..05de8d9f 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -176,6 +176,8 @@ class LightRAG:
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
enable_llm_cache: bool = True
+ # Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag
+ enable_llm_cache_for_entity_extract: bool = False
# extension
addon_params: dict = field(default_factory=dict)
@@ -402,6 +404,7 @@ class LightRAG:
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
+ llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
diff --git a/lightrag/operate.py b/lightrag/operate.py
index f21e41ff..b2c4d215 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -253,9 +253,13 @@ async def extract_entities(
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
+ llm_response_cache: BaseKVStorage = None,
) -> Union[BaseGraphStorage, None]:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
+ enable_llm_cache_for_entity_extract: bool = global_config[
+ "enable_llm_cache_for_entity_extract"
+ ]
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
@@ -300,6 +304,52 @@ async def extract_entities(
already_entities = 0
already_relations = 0
+ async def _user_llm_func_with_cache(
+ input_text: str, history_messages: list[dict[str, str]] = None
+ ) -> str:
+ if enable_llm_cache_for_entity_extract and llm_response_cache:
+ need_to_restore = False
+ if (
+ global_config["embedding_cache_config"]
+ and global_config["embedding_cache_config"]["enabled"]
+ ):
+ new_config = global_config.copy()
+ new_config["embedding_cache_config"] = None
+ new_config["enable_llm_cache"] = True
+ llm_response_cache.global_config = new_config
+ need_to_restore = True
+ if history_messages:
+ history = json.dumps(history_messages)
+ _prompt = history + "\n" + input_text
+ else:
+ _prompt = input_text
+
+ arg_hash = compute_args_hash(_prompt)
+ cached_return, _1, _2, _3 = await handle_cache(
+ llm_response_cache, arg_hash, _prompt, "default"
+ )
+ if need_to_restore:
+ llm_response_cache.global_config = global_config
+ if cached_return:
+ return cached_return
+
+ if history_messages:
+ res: str = await use_llm_func(
+ input_text, history_messages=history_messages
+ )
+ else:
+ res: str = await use_llm_func(input_text)
+ await save_to_cache(
+ llm_response_cache,
+ CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+ )
+ return res
+
+ if history_messages:
+ return await use_llm_func(input_text, history_messages=history_messages)
+ else:
+ return await use_llm_func(input_text)
+
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
nonlocal already_processed, already_entities, already_relations
chunk_key = chunk_key_dp[0]
@@ -310,17 +360,19 @@ async def extract_entities(
**context_base, input_text="{input_text}"
).format(**context_base, input_text=content)
- final_result = await use_llm_func(hint_prompt)
+ final_result = await _user_llm_func_with_cache(hint_prompt)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
- glean_result = await use_llm_func(continue_prompt, history_messages=history)
+ glean_result = await _user_llm_func_with_cache(
+ continue_prompt, history_messages=history
+ )
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
final_result += glean_result
if now_glean_index == entity_extract_max_gleaning - 1:
break
- if_loop_result: str = await use_llm_func(
+ if_loop_result: str = await _user_llm_func_with_cache(
if_loop_prompt, history_messages=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
diff --git a/lightrag/utils.py b/lightrag/utils.py
index b7c9649a..1f6bf405 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -454,7 +454,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
# For naive mode, only use simple cache matching
if mode == "naive":
- mode_cache = await hashing_kv.get_by_id(mode) or {}
+ if exists_func(hashing_kv, "get_by_mode_and_id"):
+ mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
+ else:
+ mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, None, None, None
@@ -488,7 +491,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
return best_cached_response, None, None, None
else:
# Use regular cache
- mode_cache = await hashing_kv.get_by_id(mode) or {}
+ if exists_func(hashing_kv, "get_by_mode_and_id"):
+ mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
+ else:
+ mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
@@ -510,7 +516,13 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
return
- mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+ if exists_func(hashing_kv, "get_by_mode_and_id"):
+ mode_cache = (
+ await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
+ or {}
+ )
+ else:
+ mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
@@ -543,3 +555,15 @@ def safe_unicode_decode(content):
)
return decoded_content
+
+
+def exists_func(obj, func_name: str) -> bool:
+ """Check if a function exists in an object or not.
+ :param obj:
+ :param func_name:
+ :return: True / False
+ """
+ if callable(getattr(obj, func_name, None)):
+ return True
+ else:
+ return False