Merge branch 'HKUDS:main' into main
README.md (33 changed lines)
@@ -455,9 +455,38 @@ For production level scenarios you will most likely want to leverage an enterpri
 * If you prefer docker, please start with this image if you are a beginner to avoid hiccups (DO read the overview): https://hub.docker.com/r/shangor/postgres-for-rag
 * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
 * Create index for AGE example: (Change below `dickens` to your graph name if necessary)
-```
+```sql
+load 'age';
 SET search_path = ag_catalog, "$user", public;
-CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
+CREATE INDEX CONCURRENTLY entity_p_idx ON dickens."Entity" (id);
+CREATE INDEX CONCURRENTLY vertex_p_idx ON dickens."_ag_label_vertex" (id);
+CREATE INDEX CONCURRENTLY directed_p_idx ON dickens."DIRECTED" (id);
+CREATE INDEX CONCURRENTLY directed_eid_idx ON dickens."DIRECTED" (end_id);
+CREATE INDEX CONCURRENTLY directed_sid_idx ON dickens."DIRECTED" (start_id);
+CREATE INDEX CONCURRENTLY directed_seid_idx ON dickens."DIRECTED" (start_id,end_id);
+CREATE INDEX CONCURRENTLY edge_p_idx ON dickens."_ag_label_edge" (id);
+CREATE INDEX CONCURRENTLY edge_sid_idx ON dickens."_ag_label_edge" (start_id);
+CREATE INDEX CONCURRENTLY edge_eid_idx ON dickens."_ag_label_edge" (end_id);
+CREATE INDEX CONCURRENTLY edge_seid_idx ON dickens."_ag_label_edge" (start_id,end_id);
+create INDEX CONCURRENTLY vertex_idx_node_id ON dickens."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
+create INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
+CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" using gin(properties);
+ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx;
+
+-- drop if necessary
+drop INDEX entity_p_idx;
+drop INDEX vertex_p_idx;
+drop INDEX directed_p_idx;
+drop INDEX directed_eid_idx;
+drop INDEX directed_sid_idx;
+drop INDEX directed_seid_idx;
+drop INDEX edge_p_idx;
+drop INDEX edge_sid_idx;
+drop INDEX edge_eid_idx;
+drop INDEX edge_seid_idx;
+drop INDEX vertex_idx_node_id;
+drop INDEX entity_idx_node_id;
+drop INDEX entity_node_id_gin_idx;
 ```
 * Known issue of the Apache AGE: The released versions got below issue:
 > You might find that the properties of the nodes/edges are empty.
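For reference, the DDL above can also be applied programmatically. Below is a minimal sketch using psycopg2 against an AGE-enabled Postgres; the connection settings are placeholders, and note that `CREATE INDEX CONCURRENTLY` refuses to run inside a transaction block, so autocommit is enabled first.

```python
import psycopg2

# Placeholder connection settings; point these at your AGE-enabled Postgres.
conn = psycopg2.connect(
    host="localhost", port=5432, dbname="rag", user="rag", password="rag"
)
conn.autocommit = True  # CREATE INDEX CONCURRENTLY cannot run inside a transaction

statements = [
    "LOAD 'age';",
    'SET search_path = ag_catalog, "$user", public;',
    'CREATE INDEX CONCURRENTLY entity_p_idx ON dickens."Entity" (id);',
    'CREATE INDEX CONCURRENTLY vertex_p_idx ON dickens."_ag_label_vertex" (id);',
    # ... remaining CREATE INDEX statements from the snippet above ...
]

with conn.cursor() as cur:
    for stmt in statements:
        cur.execute(stmt)
conn.close()
```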
@@ -13,18 +13,6 @@ from fastapi import (
 from typing import Dict
 import threading
 
-# Global progress tracker
-scan_progress: Dict = {
-    "is_scanning": False,
-    "current_file": "",
-    "indexed_count": 0,
-    "total_files": 0,
-    "progress": 0,
-}
-
-# Lock for thread-safe operations
-progress_lock = threading.Lock()
-
 import json
 import os
 
@@ -34,7 +22,7 @@ import logging
 import argparse
 import time
 import re
-from typing import List, Dict, Any, Optional, Union
+from typing import List, Any, Optional, Union
 from lightrag import LightRAG, QueryParam
 from lightrag.api import __api_version__
 
@@ -57,8 +45,21 @@ import pipmaster as pm
 
 from dotenv import load_dotenv
 
+# Load environment variables
 load_dotenv()
 
+# Global progress tracker
+scan_progress: Dict = {
+    "is_scanning": False,
+    "current_file": "",
+    "indexed_count": 0,
+    "total_files": 0,
+    "progress": 0,
+}
+
+# Lock for thread-safe operations
+progress_lock = threading.Lock()
+
 
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text
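Since `scan_progress` and `progress_lock` are module-level, any code that reports scanning progress is expected to mutate the dict only while holding the lock. A small illustrative sketch of that pattern (the helper and file name below are hypothetical):

```python
import threading

# Stand-ins mirroring the module-level objects defined above.
scan_progress = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 10,
    "progress": 0,
}
progress_lock = threading.Lock()


def report_progress(current_file: str) -> None:
    # Take the lock so concurrent scans never interleave partial updates.
    with progress_lock:
        scan_progress["is_scanning"] = True
        scan_progress["current_file"] = current_file
        scan_progress["indexed_count"] += 1
        total = max(scan_progress["total_files"], 1)
        scan_progress["progress"] = 100 * scan_progress["indexed_count"] / total


report_progress("docs/sample.txt")  # hypothetical file name
print(scan_progress["progress"])    # 10.0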
@@ -918,6 +919,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )
     else:
         rag = LightRAG(

@@ -941,6 +948,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )
 
     async def index_file(file_path: Union[str, Path]) -> None:
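The new `embedding_cache_config` turns on answer reuse by embedding similarity: a query whose embedding is at least `similarity_threshold` (0.95 here) cosine-similar to a cached prompt can be answered from the cache, optionally double-checked by the LLM when `use_llm_check` is true. A simplified, standalone sketch of that lookup (not the library's internal code):

```python
import numpy as np


def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))


def lookup_cached_answer(query_emb, cached_items, similarity_threshold=0.95):
    """Return the best cached answer if it clears the threshold, else None.

    cached_items is a list of (prompt_embedding, answer) pairs.
    """
    best_sim, best_answer = -1.0, None
    for cached_emb, answer in cached_items:
        sim = cosine_similarity(np.asarray(query_emb), np.asarray(cached_emb))
        if sim > best_sim:
            best_sim, best_answer = sim, answer
    return best_answer if best_sim >= similarity_threshold else None
```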
@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
 
     def __post_init__(self):
+        # Initialize lock only for file operations
+        self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         self.cosine_better_than_threshold = config.get(

@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         logger.info(
-            f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
+            f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
         )
         results = self._client.query(
             query=embedding,

@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error deleting relations for {entity_name}: {e}")
 
     async def index_done_callback(self):
-        self._client.save()
+        # Protect file write operation
+        async with self._save_lock:
+            self._client.save()

@@ -30,6 +30,7 @@ from ..base import (
     DocStatus,
     DocProcessingStatus,
     BaseGraphStorage,
+    T,
 )
 
 if sys.platform.startswith("win"):
@@ -442,6 +443,22 @@ class PGDocStatusStorage(DocStatusStorage):
         existed = set([element["id"] for element in result])
         return set(data) - existed
 
+    async def get_by_id(self, id: str) -> Union[T, None]:
+        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
+        params = {"workspace": self.db.workspace, "id": id}
+        result = await self.db.query(sql, params, True)
+        if result is None:
+            return None
+        else:
+            return DocProcessingStatus(
+                content_length=result[0]["content_length"],
+                content_summary=result[0]["content_summary"],
+                status=result[0]["status"],
+                chunks_count=result[0]["chunks_count"],
+                created_at=result[0]["created_at"],
+                updated_at=result[0]["updated_at"],
+            )
+
     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
         sql = """SELECT status as "status", COUNT(1) as "count"
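The added `get_by_id` returns a `DocProcessingStatus` object instead of a raw row, which is what the attribute-style check (`current_doc.status`) elsewhere in this patch relies on. A hedged usage sketch, assuming an already-initialized `PGDocStatusStorage` instance:

```python
from lightrag.base import DocStatus  # DocStatus comes from the base module imported above


async def should_retry(doc_status_storage, doc_id: str) -> bool:
    """Re-process a document only if it is unknown or previously failed."""
    current_doc = await doc_status_storage.get_by_id(doc_id)
    if current_doc is None:
        return True  # never seen before
    return current_doc.status == DocStatus.FAILED  # attribute access, not dict indexing
```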
@@ -884,9 +901,9 @@ class PGGraphStorage(BaseGraphStorage):
 
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (n:Entity {node_id: "%s"})
-                     OPTIONAL MATCH (n)-[r]-(connected)
-                     RETURN n, r, connected
-                   $$) AS (n agtype, r agtype, connected agtype)""" % (
+                     OPTIONAL MATCH (n)-[]-(connected)
+                     RETURN n, connected
+                   $$) AS (n agtype, connected agtype)""" % (
             self.graph_name,
             label,
         )
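Since the relationship variable is no longer returned, the `AS` clause now lists only two `agtype` columns. For illustration, this is roughly the SQL string the `%` interpolation produces (the label value is hypothetical):

```python
graph_name, label = "dickens", "SOME_NODE_ID"  # label value is hypothetical
query = """SELECT * FROM cypher('%s', $$
             MATCH (n:Entity {node_id: "%s"})
             OPTIONAL MATCH (n)-[]-(connected)
             RETURN n, connected
           $$) AS (n agtype, connected agtype)""" % (
    graph_name,
    label,
)
print(query)
```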
@@ -231,7 +231,7 @@ class LightRAG:
 
         self.llm_response_cache = self.key_string_value_json_storage_cls(
             namespace="llm_response_cache",
-            embedding_func=None,
+            embedding_func=self.embedding_func,
         )
 
         ####

@@ -275,7 +275,7 @@ class LightRAG:
         else:
             hashing_kv = self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             )
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(

@@ -373,7 +373,7 @@ class LightRAG:
             doc_id
             for doc_id in new_docs.keys()
             if (current_doc := await self.doc_status.get_by_id(doc_id)) is None
-            or current_doc["status"] == DocStatus.FAILED
+            or current_doc.status == DocStatus.FAILED
         }
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
 
@@ -916,7 +916,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
                 prompt=prompt,
             )

@@ -933,7 +933,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":

@@ -952,7 +952,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:

@@ -993,7 +993,7 @@ class LightRAG:
                 or self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
 
@@ -1024,7 +1024,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "naive":
@@ -1040,7 +1040,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":

@@ -1059,7 +1059,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:
@@ -352,16 +352,6 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            need_to_restore = False
-            if (
-                global_config["embedding_cache_config"]
-                and global_config["embedding_cache_config"]["enabled"]
-            ):
-                new_config = global_config.copy()
-                new_config["embedding_cache_config"] = None
-                new_config["enable_llm_cache"] = True
-                llm_response_cache.global_config = new_config
-                need_to_restore = True
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text

@@ -370,10 +360,13 @@ async def extract_entities(
 
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
+                cache_type="extract",
+                force_llm_cache=True,
             )
-            if need_to_restore:
-                llm_response_cache.global_config = global_config
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
                 statistic_data["llm_cache"] += 1

@@ -387,7 +380,12 @@ async def extract_entities(
             res: str = await use_llm_func(input_text)
             await save_to_cache(
                 llm_response_cache,
-                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+                CacheData(
+                    args_hash=arg_hash,
+                    content=res,
+                    prompt=_prompt,
+                    cache_type="extract",
+                ),
             )
             return res
 
@@ -740,7 +738,7 @@ async def extract_keywords_only(
     # 6. Parse out JSON from the LLM response
     match = re.search(r"\{.*\}", result, re.DOTALL)
     if not match:
-        logger.error("No JSON-like structure found in the result.")
+        logger.error("No JSON-like structure found in the LLM respond.")
         return [], []
     try:
         keywords_data = json.loads(match.group(0))

@@ -752,7 +750,11 @@ async def extract_keywords_only(
     ll_keywords = keywords_data.get("low_level_keywords", [])
 
     # 7. Cache only the processed keywords with cache type
-    cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
+    if hl_keywords or ll_keywords:
+        cache_data = {
+            "high_level_keywords": hl_keywords,
+            "low_level_keywords": ll_keywords,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
|
@@ -290,9 +290,8 @@ PROMPTS[
|
|||||||
Question 1: {original_prompt}
|
Question 1: {original_prompt}
|
||||||
Question 2: {cached_prompt}
|
Question 2: {cached_prompt}
|
||||||
|
|
||||||
Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
|
Please evaluate whether these two questions are semantically similar, and whether the answer to Question 2 can be used to answer Question 1, provide a similarity score between 0 and 1 directly.
|
||||||
1. Whether these two questions are semantically similar
|
|
||||||
2. Whether the answer to Question 2 can be used to answer Question 1
|
|
||||||
Similarity score criteria:
|
Similarity score criteria:
|
||||||
0: Completely unrelated or answer cannot be reused, including but not limited to:
|
0: Completely unrelated or answer cannot be reused, including but not limited to:
|
||||||
- The questions have different topics
|
- The questions have different topics
|
||||||
|
@@ -58,16 +58,9 @@ class EmbeddingFunc:
     embedding_dim: int
     max_token_size: int
     func: callable
-    concurrent_limit: int = 16
+    # concurrent_limit: int = 16
 
-    def __post_init__(self):
-        if self.concurrent_limit != 0:
-            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
-        else:
-            self._semaphore = UnlimitedSemaphore()
-
     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        async with self._semaphore:
-            return await self.func(*args, **kwargs)
+        return await self.func(*args, **kwargs)
 
 
@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
-        cache_type: Type of cache (e.g., 'keywords', 'query')
+        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
     Returns:
         str: Hash string
     """
@@ -131,21 +124,16 @@ def compute_mdhash_id(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()
 
 
-def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
-    """Add restriction of maximum async calling times for a async func"""
+def limit_async_func_call(max_size: int):
+    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
 
     def final_decro(func):
-        """Not using async.Semaphore to aovid use nest-asyncio"""
-        __current_size = 0
+        sem = asyncio.Semaphore(max_size)
 
         @wraps(func)
         async def wait_func(*args, **kwargs):
-            nonlocal __current_size
-            while __current_size >= max_size:
-                await asyncio.sleep(waitting_time)
-            __current_size += 1
-            result = await func(*args, **kwargs)
-            __current_size -= 1
-            return result
+            async with sem:
+                result = await func(*args, **kwargs)
+                return result
 
         return wait_func
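The busy-wait counter is gone; an `asyncio.Semaphore` now bounds concurrency directly. A self-contained sketch of how the new limiter behaves (the decorated function below is a stand-in for an LLM or embedding call):

```python
import asyncio
from functools import wraps


def limit_async_func_call(max_size: int):
    """Semaphore-based limiter mirroring the implementation above."""

    def final_decro(func):
        sem = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wait_func(*args, **kwargs):
            async with sem:  # at most max_size calls run concurrently
                return await func(*args, **kwargs)

        return wait_func

    return final_decro


@limit_async_func_call(2)
async def fake_llm(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for a slow LLM/embedding request
    return i


async def main() -> None:
    # Six calls, but only two in flight at a time (~0.3 s total instead of ~0.1 s).
    print(await asyncio.gather(*(fake_llm(i) for i in range(6))))


asyncio.run(main())
```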
@@ -380,6 +368,9 @@ async def get_best_cached_response(
     original_prompt=None,
     cache_type=None,
 ) -> Union[str, None]:
+    logger.debug(
+        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
+    )
     mode_cache = await hashing_kv.get_by_id(mode)
     if not mode_cache:
         return None
@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
 
-def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
+def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
     """Quantize embedding to specified bits"""
+    # Convert list to numpy array if needed
+    if isinstance(embedding, list):
+        embedding = np.array(embedding)
+
     # Calculate min/max values for reconstruction
     min_val = embedding.min()
     max_val = embedding.max()
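For context, the quantization pair stores each embedding as unsigned integers plus the min/max needed to reconstruct it. A self-contained round-trip sketch of the scheme (the scale formula is the standard min-max choice and is assumed here, not copied from the library):

```python
import numpy as np


def quantize_embedding(embedding, bits: int = 8) -> tuple:
    """Min-max quantize a float vector to unsigned integers (sketch)."""
    if isinstance(embedding, list):  # plain lists are accepted, as in the patch
        embedding = np.array(embedding)
    min_val, max_val = embedding.min(), embedding.max()
    scale = (max_val - min_val) / (2**bits - 1)
    quantized = np.round((embedding - min_val) / scale).astype(np.uint8)
    return quantized, min_val, max_val


def dequantize_embedding(quantized, min_val, max_val, bits: int = 8) -> np.ndarray:
    """Reverse the quantization using the stored min/max."""
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)


vec = [0.12, -0.5, 0.33, 0.9]
q, lo, hi = quantize_embedding(vec)
print(dequantize_embedding(q, lo, hi))  # close to the original, within ~(hi-lo)/255
```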
@@ -491,21 +486,21 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)
 
 
-async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
+async def handle_cache(
+    hashing_kv,
+    args_hash,
+    prompt,
+    mode="default",
+    cache_type=None,
+    force_llm_cache=False,
+):
     """Generic cache handling function"""
-    if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
-        return None, None, None, None
-    # For default mode, only use simple cache matching
-    if mode == "default":
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
+    if hashing_kv is None or not (
+        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
+    ):
         return None, None, None, None
 
+    if mode != "default":
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -517,10 +512,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
         quantized = min_val = max_val = None
         if is_embedding_cache_enabled:
             # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
-
-            current_embedding = await embedding_model_func([prompt])
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
             best_cached_response = await get_best_cached_response(
                 hashing_kv,
@@ -529,12 +522,15 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
                 mode=mode,
                 use_llm_check=use_llm_check,
                 llm_func=llm_model_func if use_llm_check else None,
-                original_prompt=prompt if use_llm_check else None,
+                original_prompt=prompt,
                 cache_type=cache_type,
             )
             if best_cached_response is not None:
                 return best_cached_response, None, None, None
             else:
+                return None, quantized, min_val, max_val
+
+    # For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
     # Use regular cache
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -543,7 +539,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
     if args_hash in mode_cache:
         return mode_cache[args_hash]["return"], None, None, None
 
-    return None, quantized, min_val, max_val
+    return None, None, None, None
 
 
 @dataclass
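Taken together, the `handle_cache` changes amount to: bail out unless caching is enabled or the caller forces it (`force_llm_cache=True`, as `extract_entities` now does), try the embedding-similarity cache only for non-default modes, and otherwise fall back to exact hash matching. A condensed, dict-backed sketch of that flow (the quantized-embedding return values are omitted for brevity, so this is not the library's function):

```python
async def handle_cache_sketch(
    cache: dict,
    args_hash: str,
    prompt: str,
    mode: str = "default",
    force_llm_cache: bool = False,
    enable_llm_cache: bool = True,
    embedding_lookup=None,  # async callable returning a cached answer or None
):
    # 1. Skip caching entirely unless it is enabled or explicitly forced.
    if cache is None or not (force_llm_cache or enable_llm_cache):
        return None

    # 2. Non-default modes may consult the embedding-similarity cache first.
    if mode != "default" and embedding_lookup is not None:
        hit = await embedding_lookup(prompt)
        if hit is not None:
            return hit

    # 3. Fall back to exact matching on (mode, args_hash).
    return cache.get(mode, {}).get(args_hash, {}).get("return")
```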
@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
 
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
+        "cache_type": cache_data.cache_type,
         "embedding": cache_data.quantized.tobytes().hex()
         if cache_data.quantized is not None
         else None,