Merge branch 'HKUDS:main' into main
@@ -455,9 +455,38 @@ For production level scenarios you will most likely want to leverage an enterpri
* If you prefer docker, please start with this image if you are a beginner to avoid hiccups (DO read the overview): https://hub.docker.com/r/shangor/postgres-for-rag
* How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
* Create index for AGE example: (Change below `dickens` to your graph name if necessary)
```
```sql
load 'age';
SET search_path = ag_catalog, "$user", public;
CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
CREATE INDEX CONCURRENTLY entity_p_idx ON dickens."Entity" (id);
CREATE INDEX CONCURRENTLY vertex_p_idx ON dickens."_ag_label_vertex" (id);
CREATE INDEX CONCURRENTLY directed_p_idx ON dickens."DIRECTED" (id);
CREATE INDEX CONCURRENTLY directed_eid_idx ON dickens."DIRECTED" (end_id);
CREATE INDEX CONCURRENTLY directed_sid_idx ON dickens."DIRECTED" (start_id);
CREATE INDEX CONCURRENTLY directed_seid_idx ON dickens."DIRECTED" (start_id,end_id);
CREATE INDEX CONCURRENTLY edge_p_idx ON dickens."_ag_label_edge" (id);
CREATE INDEX CONCURRENTLY edge_sid_idx ON dickens."_ag_label_edge" (start_id);
CREATE INDEX CONCURRENTLY edge_eid_idx ON dickens."_ag_label_edge" (end_id);
CREATE INDEX CONCURRENTLY edge_seid_idx ON dickens."_ag_label_edge" (start_id,end_id);
create INDEX CONCURRENTLY vertex_idx_node_id ON dickens."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
create INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" using gin(properties);
ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx;

-- drop if necessary
drop INDEX entity_p_idx;
drop INDEX vertex_p_idx;
drop INDEX directed_p_idx;
drop INDEX directed_eid_idx;
drop INDEX directed_sid_idx;
drop INDEX directed_seid_idx;
drop INDEX edge_p_idx;
drop INDEX edge_sid_idx;
drop INDEX edge_eid_idx;
drop INDEX edge_seid_idx;
drop INDEX vertex_idx_node_id;
drop INDEX entity_idx_node_id;
drop INDEX entity_node_id_gin_idx;
```
* Known issue of the Apache AGE: The released versions got below issue:
> You might find that the properties of the nodes/edges are empty.
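A minimal sketch (not part of this commit) of one way to apply the README's index statements from Python using psycopg2; the DSN and graph name (`dickens`) are placeholders, and autocommit is needed because `CREATE INDEX CONCURRENTLY` cannot run inside a transaction block:

```python
# Hedged example: applies the AGE index DDL above with psycopg2 (hypothetical DSN).
import psycopg2

AGE_INDEX_DDL = [
    """CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'))""",
    """CREATE INDEX CONCURRENTLY entity_p_idx ON dickens."Entity" (id)""",
    # ... remaining CREATE INDEX statements from the README block above ...
]

def create_age_indexes(dsn: str) -> None:
    conn = psycopg2.connect(dsn)
    # CREATE INDEX CONCURRENTLY cannot run inside a transaction block,
    # so enable autocommit for this session.
    conn.autocommit = True
    try:
        with conn.cursor() as cur:
            cur.execute("LOAD 'age'")
            cur.execute('SET search_path = ag_catalog, "$user", public')
            for ddl in AGE_INDEX_DDL:
                cur.execute(ddl)
    finally:
        conn.close()

if __name__ == "__main__":
    create_age_indexes("postgresql://rag:rag@localhost:5432/rag")  # hypothetical DSN
```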
@@ -13,18 +13,6 @@ from fastapi import (
from typing import Dict
import threading

# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}

# Lock for thread-safe operations
progress_lock = threading.Lock()

import json
import os

@@ -34,7 +22,7 @@ import logging
import argparse
import time
import re
from typing import List, Dict, Any, Optional, Union
from typing import List, Any, Optional, Union
from lightrag import LightRAG, QueryParam
from lightrag.api import __api_version__

@@ -57,8 +45,21 @@ import pipmaster as pm

from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}

# Lock for thread-safe operations
progress_lock = threading.Lock()


def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
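The change above relocates the global `scan_progress` tracker and `progress_lock` to after `load_dotenv()`. A minimal sketch of how updates and reads of that dict can be guarded by the lock so concurrent handlers see a consistent view; the helper function names are hypothetical, not from the commit:

```python
import threading
from typing import Dict

scan_progress: Dict = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
progress_lock = threading.Lock()

def start_scan(total_files: int) -> None:
    # Reset the tracker atomically before a new scan begins.
    with progress_lock:
        scan_progress.update(
            is_scanning=True, current_file="", indexed_count=0,
            total_files=total_files, progress=0,
        )

def mark_file_indexed(file_name: str) -> None:
    # Update counters under the lock so readers never see a half-updated state.
    with progress_lock:
        scan_progress["current_file"] = file_name
        scan_progress["indexed_count"] += 1
        total = scan_progress["total_files"] or 1
        scan_progress["progress"] = round(100 * scan_progress["indexed_count"] / total)

def snapshot_progress() -> Dict:
    # Copy under the lock so an API endpoint can return a stable view.
    with progress_lock:
        return dict(scan_progress)
```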
@@ -918,6 +919,12 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
)
else:
rag = LightRAG(

@@ -941,6 +948,12 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
)

async def index_file(file_path: Union[str, Path]) -> None:
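A condensed sketch of the cache-related arguments this commit wires into the `LightRAG` constructor. The cache kwargs mirror the hunks above; `working_dir` and the model functions are placeholders you must supply yourself:

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",               # placeholder path
    llm_model_func=my_llm_model_func,          # assumed to be defined elsewhere
    embedding_func=my_embedding_func,          # assumed to be defined elsewhere
    vector_db_storage_cls_kwargs={
        "cosine_better_than_threshold": 0.2,   # mirrors args.cosine_threshold
    },
    enable_llm_cache_for_entity_extract=False,  # True while debugging to reduce LLM fees
    embedding_cache_config={
        "enabled": True,
        "similarity_threshold": 0.95,  # reuse cached answers for near-duplicate prompts
        "use_llm_check": False,        # skip the extra LLM similarity verification
    },
)
```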
@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))

def __post_init__(self):
# Initialize lock only for file operations
self._save_lock = asyncio.Lock()
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(

@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
embedding = embedding[0]
logger.info(
f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
)
results = self._client.query(
query=embedding,

@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting relations for {entity_name}: {e}")

async def index_done_callback(self):
self._client.save()
# Protect file write operation
async with self._save_lock:
self._client.save()
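A self-contained sketch (not the repo class) of the pattern the `NanoVectorDBStorage` change introduces: create an `asyncio.Lock` at init time and serialize file writes in the index-done callback so concurrent indexing tasks cannot interleave saves:

```python
import asyncio
import json

class TinyVectorStore:
    def __init__(self, path: str):
        self._path = path
        self._data: dict[str, list[float]] = {}
        # Initialize lock only for file operations
        self._save_lock = asyncio.Lock()

    async def upsert(self, key: str, vector: list[float]) -> None:
        self._data[key] = vector

    async def index_done_callback(self) -> None:
        # Protect the file write: only one coroutine may persist at a time.
        async with self._save_lock:
            with open(self._path, "w") as f:
                json.dump(self._data, f)

async def main() -> None:
    store = TinyVectorStore("vectors.json")
    await asyncio.gather(*(store.upsert(f"doc-{i}", [float(i)]) for i in range(3)))
    await store.index_done_callback()

asyncio.run(main())
```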
@@ -30,6 +30,7 @@ from ..base import (
DocStatus,
DocProcessingStatus,
BaseGraphStorage,
T,
)

if sys.platform.startswith("win"):

@@ -442,6 +443,22 @@ class PGDocStatusStorage(DocStatusStorage):
existed = set([element["id"] for element in result])
return set(data) - existed

async def get_by_id(self, id: str) -> Union[T, None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
params = {"workspace": self.db.workspace, "id": id}
result = await self.db.query(sql, params, True)
if result is None:
return None
else:
return DocProcessingStatus(
content_length=result[0]["content_length"],
content_summary=result[0]["content_summary"],
status=result[0]["status"],
chunks_count=result[0]["chunks_count"],
created_at=result[0]["created_at"],
updated_at=result[0]["updated_at"],
)

async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count"

@@ -884,9 +901,9 @@ class PGGraphStorage(BaseGraphStorage):

query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected
$$) AS (n agtype, r agtype, connected agtype)""" % (
OPTIONAL MATCH (n)-[]-(connected)
RETURN n, connected
$$) AS (n agtype, connected agtype)""" % (
self.graph_name,
label,
)
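For illustration only, what the updated template expands to after %-interpolation; `dickens` and `SCROOGE` are placeholder graph/label values, not taken from the commit:

```python
graph_name = "dickens"
label = "SCROOGE"

query = """SELECT * FROM cypher('%s', $$
    MATCH (n:Entity {node_id: "%s"})
    OPTIONAL MATCH (n)-[]-(connected)
    RETURN n, connected
$$) AS (n agtype, connected agtype)""" % (
    graph_name,
    label,
)
print(query)  # the literal SQL string handed to PostgreSQL/Apache AGE
```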
@@ -231,7 +231,7 @@ class LightRAG:

self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
embedding_func=self.embedding_func,
)

####

@@ -275,7 +275,7 @@ class LightRAG:
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
embedding_func=self.embedding_func,
)

self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(

@@ -373,7 +373,7 @@ class LightRAG:
doc_id
for doc_id in new_docs.keys()
if (current_doc := await self.doc_status.get_by_id(doc_id)) is None
or current_doc["status"] == DocStatus.FAILED
or current_doc.status == DocStatus.FAILED
}
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

@@ -916,7 +916,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
prompt=prompt,
)

@@ -933,7 +933,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":

@@ -952,7 +952,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
else:

@@ -993,7 +993,7 @@ class LightRAG:
or self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
@@ -1024,7 +1024,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":

@@ -1040,7 +1040,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":

@@ -1059,7 +1059,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
else:
@@ -352,16 +352,6 @@ async def extract_entities(
input_text: str, history_messages: list[dict[str, str]] = None
) -> str:
if enable_llm_cache_for_entity_extract and llm_response_cache:
need_to_restore = False
if (
global_config["embedding_cache_config"]
and global_config["embedding_cache_config"]["enabled"]
):
new_config = global_config.copy()
new_config["embedding_cache_config"] = None
new_config["enable_llm_cache"] = True
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text

@@ -370,10 +360,13 @@ async def extract_entities(

arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
llm_response_cache,
arg_hash,
_prompt,
"default",
cache_type="extract",
force_llm_cache=True,
)
if need_to_restore:
llm_response_cache.global_config = global_config
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1

@@ -387,7 +380,12 @@ async def extract_entities(
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
CacheData(
args_hash=arg_hash,
content=res,
prompt=_prompt,
cache_type="extract",
),
)
return res

@@ -740,7 +738,7 @@ async def extract_keywords_only(
# 6. Parse out JSON from the LLM response
match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.error("No JSON-like structure found in the result.")
logger.error("No JSON-like structure found in the LLM respond.")
return [], []
try:
keywords_data = json.loads(match.group(0))

@@ -752,20 +750,24 @@ async def extract_keywords_only(
ll_keywords = keywords_data.get("low_level_keywords", [])

# 7. Cache only the processed keywords with cache type
cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=json.dumps(cache_data),
prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode,
cache_type="keywords",
),
)
if hl_keywords or ll_keywords:
cache_data = {
"high_level_keywords": hl_keywords,
"low_level_keywords": ll_keywords,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=json.dumps(cache_data),
prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode,
cache_type="keywords",
),
)
return hl_keywords, ll_keywords
@@ -290,9 +290,8 @@ PROMPTS[
Question 1: {original_prompt}
Question 2: {cached_prompt}

Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
1. Whether these two questions are semantically similar
2. Whether the answer to Question 2 can be used to answer Question 1
Please evaluate whether these two questions are semantically similar, and whether the answer to Question 2 can be used to answer Question 1, provide a similarity score between 0 and 1 directly.

Similarity score criteria:
0: Completely unrelated or answer cannot be reused, including but not limited to:
- The questions have different topics

@@ -58,17 +58,10 @@ class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
concurrent_limit: int = 16

def __post_init__(self):
if self.concurrent_limit != 0:
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
else:
self._semaphore = UnlimitedSemaphore()
# concurrent_limit: int = 16

async def __call__(self, *args, **kwargs) -> np.ndarray:
async with self._semaphore:
return await self.func(*args, **kwargs)
return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> Union[str, None]:
@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query')
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""

@@ -131,22 +124,17 @@ def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()


def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
def limit_async_func_call(max_size: int):
"""Add restriction of maximum concurrent async calls using asyncio.Semaphore"""

def final_decro(func):
"""Not using async.Semaphore to aovid use nest-asyncio"""
__current_size = 0
sem = asyncio.Semaphore(max_size)

@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
result = await func(*args, **kwargs)
__current_size -= 1
return result
async with sem:
result = await func(*args, **kwargs)
return result

return wait_func
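The hunk above replaces the busy-wait counter in `limit_async_func_call` with an `asyncio.Semaphore`. Reassembled from the added lines as a standalone, runnable sketch; the demo coroutine is illustrative only:

```python
import asyncio
from functools import wraps

def limit_async_func_call(max_size: int):
    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""

    def final_decro(func):
        sem = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wait_func(*args, **kwargs):
            # Only max_size coroutines may be inside func at any moment.
            async with sem:
                return await func(*args, **kwargs)

        return wait_func

    return final_decro

@limit_async_func_call(max_size=2)
async def fake_llm_call(i: int) -> int:
    await asyncio.sleep(0.1)  # at most two of these run concurrently
    return i

async def main() -> None:
    print(await asyncio.gather(*(fake_llm_call(i) for i in range(5))))

asyncio.run(main())
```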
@@ -380,6 +368,9 @@ async def get_best_cached_response(
original_prompt=None,
cache_type=None,
) -> Union[str, None]:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None

@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
return dot_product / (norm1 * norm2)


def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
"""Quantize embedding to specified bits"""
# Convert list to numpy array if needed
if isinstance(embedding, list):
embedding = np.array(embedding)

# Calculate min/max values for reconstruction
min_val = embedding.min()
max_val = embedding.max()
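The hunk shows only the start of `quantize_embedding`. Below is a consistent, self-contained sketch of 8-bit min-max quantization whose inverse matches the `(quantized * scale + min_val)` dequantize line quoted in the next hunk; the exact scale/rounding details are an assumption, not repo code:

```python
from typing import Union
import numpy as np

def quantize_embedding(embedding: Union[np.ndarray, list], bits: int = 8) -> tuple:
    """Quantize an embedding to `bits` bits using min-max scaling."""
    if isinstance(embedding, list):   # accept plain lists, as the new signature allows
        embedding = np.array(embedding)
    min_val = float(embedding.min())
    max_val = float(embedding.max())
    scale = (max_val - min_val) / (2**bits - 1) or 1.0  # guard against constant vectors
    quantized = np.round((embedding - min_val) / scale).astype(np.uint8)
    return quantized, min_val, max_val

def dequantize_embedding(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8) -> np.ndarray:
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)

vec = [0.1, -0.3, 0.7, 0.0]
q, lo, hi = quantize_embedding(vec)
print(np.max(np.abs(dequantize_embedding(q, lo, hi) - np.array(vec))))  # small reconstruction error
```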
@@ -491,59 +486,60 @@ def dequantize_embedding(
return (quantized * scale + min_val).astype(np.float32)


async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
async def handle_cache(
hashing_kv,
args_hash,
prompt,
mode="default",
cache_type=None,
force_llm_cache=False,
):
"""Generic cache handling function"""
if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
if hashing_kv is None or not (
force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
):
return None, None, None, None

# For default mode, only use simple cache matching
if mode == "default":
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, None, None, None

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
use_llm_check = embedding_cache_config.get("use_llm_check", False)

quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
llm_model_func = hashing_kv.global_config.get("llm_model_func")

current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt if use_llm_check else None,
cache_type=cache_type,
if mode != "default":
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
# Use regular cache
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
is_embedding_cache_enabled = embedding_cache_config["enabled"]
use_llm_check = embedding_cache_config.get("use_llm_check", False)

return None, quantized, min_val, max_val
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
current_embedding = await hashing_kv.embedding_func([prompt])
llm_model_func = hashing_kv.global_config.get("llm_model_func")
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt,
cache_type=cache_type,
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
return None, quantized, min_val, max_val

# For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
# Use regular cache
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None

return None, None, None, None


@dataclass
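A simplified, self-contained illustration (not the repo function) of the control flow the rewritten `handle_cache` follows: bail out unless caching is enabled or `force_llm_cache` is set, use exact hash matching for `mode="default"`, and fall back to exact matching when the embedding cache is disabled for other modes. Names and the dict-backed cache are purely illustrative:

```python
import asyncio

async def toy_handle_cache(cache: dict, args_hash: str, mode: str = "default",
                           enable_llm_cache: bool = True, force_llm_cache: bool = False,
                           embedding_cache_enabled: bool = False):
    if cache is None or not (force_llm_cache or enable_llm_cache):
        return None  # caching disabled and not forced

    if mode != "default" and embedding_cache_enabled:
        # Placeholder for the similarity-search branch (quantize + get_best_cached_response).
        return cache.get(("embedding", mode))

    # Default mode (entity extraction / naive query) or embedding cache disabled:
    # plain exact-match lookup keyed by the argument hash.
    return cache.get((mode, args_hash))

async def main() -> None:
    cache = {("default", "abc123"): "cached extraction result"}
    print(await toy_handle_cache(cache, "abc123"))                           # hit
    print(await toy_handle_cache(cache, "abc123", enable_llm_cache=False))   # miss: disabled
    print(await toy_handle_cache(cache, "abc123", enable_llm_cache=False,
                                 force_llm_cache=True))                      # hit: forced for extraction

asyncio.run(main())
```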
@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):

mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,