Merge branch 'main' into add-multi-worker-support

README.md
@@ -751,6 +751,40 @@ rag.delete_by_entity("Project Gutenberg")
 rag.delete_by_doc_id("doc_id")
 ```
+
+## Cache
+
+<details>
+<summary> <b>Clear Cache</b> </summary>
+
+You can clear the LLM response cache with different modes:
+
+```python
+# Clear all cache
+await rag.aclear_cache()
+
+# Clear local mode cache
+await rag.aclear_cache(modes=["local"])
+
+# Clear extraction cache
+await rag.aclear_cache(modes=["default"])
+
+# Clear multiple modes
+await rag.aclear_cache(modes=["local", "global", "hybrid"])
+
+# Synchronous version
+rag.clear_cache(modes=["local"])
+```
+
+Valid modes are:
+- `"default"`: Extraction cache
+- `"naive"`: Naive search cache
+- `"local"`: Local search cache
+- `"global"`: Global search cache
+- `"hybrid"`: Hybrid search cache
+- `"mix"`: Mix search cache
+
+</details>

 ## LightRAG init parameters

 <details>

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.2.2"
+__version__ = "1.2.3"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

@@ -280,10 +280,7 @@ class Neo4JStorage(BaseGraphStorage):
             MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
             RETURN properties(r) as edge_properties
             LIMIT 1
-            """.format(
-                entity_name_label_source=entity_name_label_source,
-                entity_name_label_target=entity_name_label_target,
-            )
+            """

             result = await session.run(query)
             record = await result.single()
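
Dropping the `.format(...)` call only makes sense if the query text is now interpolated exactly once upstream, presumably via an f-string; a minimal sketch of that assumption, with made-up label values:

```python
# Hypothetical reproduction of the pattern this hunk removes: if the query
# is already an f-string, a trailing .format(...) would attempt a second
# substitution pass over text that no longer contains placeholders.
entity_name_label_source = "PERSON"
entity_name_label_target = "ORGANIZATION"

query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
LIMIT 1
"""
print(query)  # labels are already interpolated; no .format() needed
```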

@@ -1697,3 +1697,50 @@ class LightRAG:
                 f"Storage implementation '{storage_name}' requires the following "
                 f"environment variables: {', '.join(missing_vars)}"
             )
+
+    async def aclear_cache(self, modes: list[str] | None = None) -> None:
+        """Clear cache data from the LLM response cache storage.
+
+        Args:
+            modes (list[str] | None): Modes of cache to clear. Options: ["default", "naive", "local", "global", "hybrid", "mix"].
+                "default" represents extraction cache.
+                If None, clears all cache.
+
+        Example:
+            # Clear all cache
+            await rag.aclear_cache()
+
+            # Clear local mode cache
+            await rag.aclear_cache(modes=["local"])
+
+            # Clear extraction cache
+            await rag.aclear_cache(modes=["default"])
+        """
+        if not self.llm_response_cache:
+            logger.warning("No cache storage configured")
+            return
+
+        valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
+
+        # Validate input
+        if modes and not all(mode in valid_modes for mode in modes):
+            raise ValueError(f"Invalid mode. Valid modes are: {valid_modes}")
+
+        try:
+            # Reset the cache storage for specified mode
+            if modes:
+                await self.llm_response_cache.delete(modes)
+                logger.info(f"Cleared cache for modes: {modes}")
+            else:
+                # Clear all modes
+                await self.llm_response_cache.delete(valid_modes)
+                logger.info("Cleared all cache")
+
+            await self.llm_response_cache.index_done_callback()
+
+        except Exception as e:
+            logger.error(f"Error while clearing cache: {e}")
+
+    def clear_cache(self, modes: list[str] | None = None) -> None:
+        """Synchronous version of aclear_cache."""
+        return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
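
For readers who want to see the new cache-clearing flow end to end, here is a condensed, runnable sketch. `InMemoryCache` is a hypothetical stand-in exposing only the `delete` and `index_done_callback` surface the method touches; the real storage backends will differ:

```python
import asyncio

VALID_MODES = ["default", "naive", "local", "global", "hybrid", "mix"]

class InMemoryCache:
    """Hypothetical stand-in for the llm_response_cache storage."""
    def __init__(self):
        self.store = {mode: {} for mode in VALID_MODES}

    async def delete(self, modes):
        # drop every entry under the requested modes
        for mode in modes:
            self.store[mode] = {}

    async def index_done_callback(self):
        pass  # persistence hook; a no-op in this sketch

async def aclear_cache(cache, modes=None):
    # condensed logic of the method added above
    if modes and not all(m in VALID_MODES for m in modes):
        raise ValueError(f"Invalid mode. Valid modes are: {VALID_MODES}")
    await cache.delete(modes or VALID_MODES)
    await cache.index_done_callback()

cache = InMemoryCache()
cache.store["local"]["q1"] = "cached answer"
asyncio.run(aclear_cache(cache, modes=["local"]))
print(cache.store["local"])  # {} -- the local-mode cache has been cleared
```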

@@ -141,17 +141,18 @@ async def _handle_single_entity_extraction(
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
     # add this record as a node in the G
-    entity_name = clean_str(record_attributes[1].upper())
+    entity_name = clean_str(record_attributes[1]).strip('"')
     if not entity_name.strip():
         return None
-    entity_type = clean_str(record_attributes[2].upper())
+    entity_type = clean_str(record_attributes[2]).strip('"')
-    entity_description = clean_str(record_attributes[3])
+    entity_description = clean_str(record_attributes[3]).strip('"')
     entity_source_id = chunk_key
     return dict(
         entity_name=entity_name,
         entity_type=entity_type,
         description=entity_description,
         source_id=entity_source_id,
+        metadata={"created_at": time.time()},
     )


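
The practical effect of swapping `.upper()` for `.strip('"')` is that extracted names keep their original casing and only lose the wrapping quotes. A toy illustration; `clean_str` below is a simplified stand-in for the project's utility of the same name, and the sample record is made up:

```python
# Simplified stand-in for lightrag.utils.clean_str
def clean_str(s: str) -> str:
    return s.strip()

record_attributes = ['"entity"', '"Project Gutenberg"', '"organization"',
                     '"A library of free eBooks."']

old_name = clean_str(record_attributes[1].upper())     # '"PROJECT GUTENBERG"'
new_name = clean_str(record_attributes[1]).strip('"')  # 'Project Gutenberg'
print(old_name, "->", new_name)
```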

@@ -162,14 +163,15 @@ async def _handle_single_relationship_extraction(
     if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
         return None
     # add this record as edge
-    source = clean_str(record_attributes[1].upper())
+    source = clean_str(record_attributes[1]).strip('"')
-    target = clean_str(record_attributes[2].upper())
+    target = clean_str(record_attributes[2]).strip('"')
-    edge_description = clean_str(record_attributes[3])
+    edge_description = clean_str(record_attributes[3]).strip('"')
-    edge_keywords = clean_str(record_attributes[4])
+    edge_keywords = clean_str(record_attributes[4]).strip('"')
     edge_source_id = chunk_key
     weight = (
-        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
+        float(record_attributes[-1].strip('"'))
+        if is_float_regex(record_attributes[-1])
+        else 1.0
     )
     return dict(
         src_id=source,
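
The same quote handling applies to the relationship weight: the final attribute may arrive as a quoted numeral, so the quotes are stripped before `float()`. A small check, with `is_float_regex` as a simplified stand-in assumed to tolerate the quotes (the call site passes the raw, still-quoted attribute):

```python
import re

def is_float_regex(value: str) -> bool:
    # Simplified stand-in; assumed to accept a quoted numeral,
    # since the call site passes the attribute before stripping.
    return bool(re.fullmatch(r'[-+]?[0-9]*\.?[0-9]+', value.strip('"')))

raw = '"0.85"'  # made-up record attribute
weight = float(raw.strip('"')) if is_float_regex(raw) else 1.0
print(weight)  # 0.85
```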

@@ -561,9 +563,13 @@ async def extract_entities(
     if entity_vdb is not None:
         data_for_vdb = {
             compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                "content": dp["entity_name"] + dp["description"],
                 "entity_name": dp["entity_name"],
+                "entity_type": dp["entity_type"],
+                "content": f"{dp['entity_name']}\n{dp['description']}",
                 "source_id": dp["source_id"],
+                "metadata": {
+                    "created_at": dp.get("metadata", {}).get("created_at", time.time())
+                },
             }
             for dp in all_entities_data
         }
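
Entity entries pushed to the vector store now carry the entity type and a creation timestamp, and the embedded text separates name from description with a newline rather than plain concatenation. A self-contained sketch; `compute_mdhash_id` is a stand-in mirroring the prefix-plus-md5 pattern, and the record values are made up:

```python
import time
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Stand-in mirroring the prefix + md5-hex id scheme
    return prefix + md5(content.encode()).hexdigest()

dp = {  # made-up entity record
    "entity_name": "Project Gutenberg",
    "entity_type": "organization",
    "description": "A library of free eBooks.",
    "source_id": "chunk-1",
    "metadata": {"created_at": 1700000000},
}

data_for_vdb = {
    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
        "entity_name": dp["entity_name"],
        "entity_type": dp["entity_type"],
        # new format: name and description on separate lines
        "content": f"{dp['entity_name']}\n{dp['description']}",
        "source_id": dp["source_id"],
        "metadata": {"created_at": dp.get("metadata", {}).get("created_at", time.time())},
    }
}
print(data_for_vdb)
```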

@@ -574,11 +580,9 @@ async def extract_entities(
             compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
+                "keywords": dp["keywords"],
+                "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
                 "source_id": dp["source_id"],
-                "content": dp["keywords"]
-                + dp["src_id"]
-                + dp["tgt_id"]
-                + dp["description"],
                 "metadata": {
                     "created_at": dp.get("metadata", {}).get("created_at", time.time())
                 },
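
Relationship entries get the parallel treatment: keywords become a first-class field, and the embedded text is restructured as a tab-separated source/target pair followed by keywords and description. The resulting string, with made-up values:

```python
dp = {"src_id": "Alice", "tgt_id": "Acme Corp", "keywords": "employment",
      "description": "Alice works at Acme Corp."}

content = f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}"
print(content)
# Alice	Acme Corp
# employment
# Alice works at Acme Corp.
```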

@@ -974,7 +978,7 @@ async def mix_kg_vector_query(
         stream=query_param.stream,
     )

-    # 清理响应内容
+    # Clean up response content
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")

@@ -986,7 +990,7 @@ async def mix_kg_vector_query(
             .strip()
         )

-    # 7. Save cache - 只有在收集完整响应后才缓存
+    # 7. Save cache - Only cache after collecting complete response
     await save_to_cache(
         hashing_kv,
         CacheData(

@@ -1142,8 +1146,19 @@ async def _get_node_data(
     )

     # build prompt
-    entites_section_list = [["id", "entity", "type", "description", "rank"]]
+    entites_section_list = [
+        [
+            "id",
+            "entity",
+            "type",
+            "description",
+            "rank", "created_at",
+        ]
+    ]
     for i, n in enumerate(node_datas):
+        created_at = n.get("created_at", "UNKNOWN")
+        if isinstance(created_at, (int, float)):
+            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
         entites_section_list.append(
             [
                 i,

@@ -1151,6 +1166,7 @@ async def _get_node_data(
                 n.get("entity_type", "UNKNOWN"),
                 n.get("description", "UNKNOWN"),
                 n["rank"],
+                created_at,
             ]
         )
     entities_context = list_of_list_to_csv(entites_section_list)
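
The `created_at` values written by the extraction hunks above are raw UNIX timestamps, so the prompt-building code converts them to a readable local-time string before they enter the CSV context. In isolation:

```python
import time

created_at = 1700000000  # raw UNIX timestamp, as written by time.time() at extraction
if isinstance(created_at, (int, float)):
    created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
print(created_at)  # e.g. '2023-11-14 22:13:20', depending on the local timezone
```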

@@ -1415,6 +1431,10 @@ async def _get_edge_data(

     entites_section_list = [["id", "entity", "type", "description", "rank"]]
     for i, n in enumerate(use_entities):
+        created_at = n.get("created_at", "Unknown")
+        # Convert timestamp to readable format
+        if isinstance(created_at, (int, float)):
+            created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
         entites_section_list.append(
             [
                 i,

@@ -1422,6 +1442,7 @@ async def _get_edge_data(
                 n.get("entity_type", "UNKNOWN"),
                 n.get("description", "UNKNOWN"),
                 n["rank"],
+                created_at,
             ]
         )
     entities_context = list_of_list_to_csv(entites_section_list)

@@ -1780,6 +1801,8 @@ async def kg_query_with_keywords(
         system_prompt=sys_prompt,
         stream=query_param.stream,
     )
+
+    # Clean up response content
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")

@@ -1791,7 +1814,7 @@ async def kg_query_with_keywords(
             .strip()
         )

-    # Save to cache
+    # 7. Save cache - Only cache after collecting complete response
     await save_to_cache(
         hashing_kv,
         CacheData(

@@ -1805,4 +1828,5 @@ async def kg_query_with_keywords(
             cache_type="query",
         ),
     )
+
     return response