diff --git a/README.md b/README.md
index 84b35654..86211b35 100644
--- a/README.md
+++ b/README.md
@@ -751,6 +751,40 @@ rag.delete_by_entity("Project Gutenberg")
rag.delete_by_doc_id("doc_id")
```
+## Cache
+
+### Clear Cache
+
+You can clear the LLM response cache with different modes:
+
+```python
+# Clear all cache
+await rag.aclear_cache()
+
+# Clear local mode cache
+await rag.aclear_cache(modes=["local"])
+
+# Clear extraction cache
+await rag.aclear_cache(modes=["default"])
+
+# Clear multiple modes
+await rag.aclear_cache(modes=["local", "global", "hybrid"])
+
+# Synchronous version
+rag.clear_cache(modes=["local"])
+```
+
+Valid modes are:
+- `"default"`: Extraction cache
+- `"naive"`: Naive search cache
+- `"local"`: Local search cache
+- `"global"`: Global search cache
+- `"hybrid"`: Hybrid search cache
+- `"mix"`: Mix search cache
+
## LightRAG init parameters
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index 250c2ae8..2d660928 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "1.2.2"
+__version__ = "1.2.3"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 0ddc611d..f5c2237a 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -280,10 +280,7 @@ class Neo4JStorage(BaseGraphStorage):
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
- """.format(
- entity_name_label_source=entity_name_label_source,
- entity_name_label_target=entity_name_label_target,
- )
+ """
result = await session.run(query)
record = await result.single()
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 44b77ae7..2cc7883d 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1697,3 +1697,50 @@ class LightRAG:
f"Storage implementation '{storage_name}' requires the following "
f"environment variables: {', '.join(missing_vars)}"
)
+
+ async def aclear_cache(self, modes: list[str] | None = None) -> None:
+ """Clear cache data from the LLM response cache storage.
+
+ Args:
+ modes (list[str] | None): Modes of cache to clear. Options: ["default", "naive", "local", "global", "hybrid", "mix"].
+ "default" represents extraction cache.
+ If None, clears all cache.
+
+ Example:
+ # Clear all cache
+ await rag.aclear_cache()
+
+ # Clear local mode cache
+ await rag.aclear_cache(modes=["local"])
+
+ # Clear extraction cache
+ await rag.aclear_cache(modes=["default"])
+ """
+ if not self.llm_response_cache:
+ logger.warning("No cache storage configured")
+ return
+
+ valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
+
+ # Validate input
+ if modes and not all(mode in valid_modes for mode in modes):
+ raise ValueError(f"Invalid mode. Valid modes are: {valid_modes}")
+
+ try:
+ # Reset the cache storage for specified mode
+ if modes:
+ await self.llm_response_cache.delete(modes)
+ logger.info(f"Cleared cache for modes: {modes}")
+ else:
+ # Clear all modes
+ await self.llm_response_cache.delete(valid_modes)
+ logger.info("Cleared all cache")
+
+ await self.llm_response_cache.index_done_callback()
+
+ except Exception as e:
+ logger.error(f"Error while clearing cache: {e}")
+
+ def clear_cache(self, modes: list[str] | None = None) -> None:
+ """Synchronous version of aclear_cache."""
+ return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
diff --git a/lightrag/operate.py b/lightrag/operate.py
index f17a422c..7db42284 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -141,17 +141,18 @@ async def _handle_single_entity_extraction(
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G
- entity_name = clean_str(record_attributes[1].upper())
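+ # Keep the entity name's original casing (previously uppercased) and strip the quotes the extraction format wraps around values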
+ entity_name = clean_str(record_attributes[1]).strip('"')
if not entity_name.strip():
return None
- entity_type = clean_str(record_attributes[2].upper())
- entity_description = clean_str(record_attributes[3])
+ entity_type = clean_str(record_attributes[2]).strip('"')
+ entity_description = clean_str(record_attributes[3]).strip('"')
entity_source_id = chunk_key
return dict(
entity_name=entity_name,
entity_type=entity_type,
description=entity_description,
source_id=entity_source_id,
+ metadata={"created_at": time.time()},
)
@@ -162,14 +163,15 @@ async def _handle_single_relationship_extraction(
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
# add this record as edge
- source = clean_str(record_attributes[1].upper())
- target = clean_str(record_attributes[2].upper())
- edge_description = clean_str(record_attributes[3])
-
- edge_keywords = clean_str(record_attributes[4])
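+ # Apply the same quote-stripping to relationship fields, keeping original casing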
+ source = clean_str(record_attributes[1]).strip('"')
+ target = clean_str(record_attributes[2]).strip('"')
+ edge_description = clean_str(record_attributes[3]).strip('"')
+ edge_keywords = clean_str(record_attributes[4]).strip('"')
edge_source_id = chunk_key
weight = (
- float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
+ float(record_attributes[-1].strip('"'))
+ if is_float_regex(record_attributes[-1].strip('"'))
+ else 1.0
)
return dict(
src_id=source,
@@ -561,9 +563,13 @@ async def extract_entities(
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
- "content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
+ "entity_type": dp["entity_type"],
+ "content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
+ "metadata": {
+ "created_at": dp.get("metadata", {}).get("created_at", time.time())
+ },
}
for dp in all_entities_data
}
@@ -574,11 +580,9 @@ async def extract_entities(
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
+ "keywords": dp["keywords"],
+ "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
- "content": dp["keywords"]
- + dp["src_id"]
- + dp["tgt_id"]
- + dp["description"],
"metadata": {
"created_at": dp.get("metadata", {}).get("created_at", time.time())
},
@@ -974,7 +978,7 @@ async def mix_kg_vector_query(
stream=query_param.stream,
)
- # 清理响应内容
+ # Clean up response content
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -986,7 +990,7 @@ async def mix_kg_vector_query(
.strip()
)
- # 7. Save cache - 只有在收集完整响应后才缓存
+ # 7. Save cache - Only cache after collecting complete response
await save_to_cache(
hashing_kv,
CacheData(
@@ -1142,8 +1146,19 @@ async def _get_node_data(
)
# build prompt
- entites_section_list = [["id", "entity", "type", "description", "rank"]]
+ entites_section_list = [
+ [
+ "id",
+ "entity",
+ "type",
+ "description",
+ "rank" "created_at",
+ ]
+ ]
for i, n in enumerate(node_datas):
+ created_at = n.get("created_at", "UNKNOWN")
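+ # Convert timestamp to readable format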
+ if isinstance(created_at, (int, float)):
+ created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
entites_section_list.append(
[
i,
@@ -1151,6 +1166,7 @@ async def _get_node_data(
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
n["rank"],
+ created_at,
]
)
entities_context = list_of_list_to_csv(entites_section_list)
@@ -1415,6 +1431,10 @@ async def _get_edge_data(
- entites_section_list = [["id", "entity", "type", "description", "rank"]]
+ entites_section_list = [["id", "entity", "type", "description", "rank", "created_at"]]
for i, n in enumerate(use_entities):
+ created_at = n.get("created_at", "UNKNOWN")
+ # Convert timestamp to readable format
+ if isinstance(created_at, (int, float)):
+ created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
entites_section_list.append(
[
i,
@@ -1422,6 +1442,7 @@ async def _get_edge_data(
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
n["rank"],
+ created_at,
]
)
entities_context = list_of_list_to_csv(entites_section_list)
@@ -1780,6 +1801,8 @@ async def kg_query_with_keywords(
system_prompt=sys_prompt,
stream=query_param.stream,
)
+
+ # Clean up response content
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -1791,18 +1814,19 @@ async def kg_query_with_keywords(
.strip()
)
- # Save to cache
- await save_to_cache(
- hashing_kv,
- CacheData(
- args_hash=args_hash,
- content=response,
- prompt=query,
- quantized=quantized,
- min_val=min_val,
- max_val=max_val,
- mode=query_param.mode,
- cache_type="query",
- ),
- )
+ # 7. Save cache - Only cache after collecting complete response
+ await save_to_cache(
+ hashing_kv,
+ CacheData(
+ args_hash=args_hash,
+ content=response,
+ prompt=query,
+ quantized=quantized,
+ min_val=min_val,
+ max_val=max_val,
+ mode=query_param.mode,
+ cache_type="query",
+ ),
+ )
+
return response