Merge pull request #693 from danielaskdd/fix-concurrent-problem
Fixed concurrent problems for document indexing and user query
@@ -13,18 +13,6 @@ from fastapi import (
 from typing import Dict
 import threading
-
-# Global progress tracker
-scan_progress: Dict = {
-    "is_scanning": False,
-    "current_file": "",
-    "indexed_count": 0,
-    "total_files": 0,
-    "progress": 0,
-}
-
-# Lock for thread-safe operations
-progress_lock = threading.Lock()
 
 import json
 import os
 
@@ -34,7 +22,7 @@ import logging
 import argparse
 import time
 import re
-from typing import List, Dict, Any, Optional, Union
+from typing import List, Any, Optional, Union
 from lightrag import LightRAG, QueryParam
 from lightrag.api import __api_version__
 
@@ -57,8 +45,21 @@ import pipmaster as pm
 
 from dotenv import load_dotenv
 
 # Load environment variables
 load_dotenv()
 
+# Global progress tracker
+scan_progress: Dict = {
+    "is_scanning": False,
+    "current_file": "",
+    "indexed_count": 0,
+    "total_files": 0,
+    "progress": 0,
+}
+
+# Lock for thread-safe operations
+progress_lock = threading.Lock()
+
+
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text
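The block moved here is shared mutable state: the scan routine updates it while progress requests read it, so every read-modify-write has to happen under progress_lock for the two sides to see consistent values. A minimal sketch of that usage pattern (the scan_directory helper and its file list are illustrative, not code from this commit):

import threading
from typing import Dict

scan_progress: Dict = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
progress_lock = threading.Lock()

def scan_directory(files: list) -> None:
    with progress_lock:  # take the lock for every read-modify-write
        scan_progress["is_scanning"] = True
        scan_progress["total_files"] = len(files)
        scan_progress["indexed_count"] = 0
    for name in files:
        # ... index the file here ...
        with progress_lock:
            scan_progress["current_file"] = name
            scan_progress["indexed_count"] += 1
            scan_progress["progress"] = scan_progress["indexed_count"] / max(len(files), 1) * 100
    with progress_lock:
        scan_progress["is_scanning"] = False

scan_directory(["doc1.txt", "doc2.txt"])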
@@ -918,6 +919,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )
     else:
         rag = LightRAG(
@@ -941,6 +948,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )
 
     async def index_file(file_path: Union[str, Path]) -> None:
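The embedding_cache_config block added to both LightRAG constructions above turns on reuse of earlier LLM answers when a new prompt embeds close enough to a cached one. A rough, standalone sketch of the similarity test it implies (made-up vectors; the real comparison happens in get_best_cached_response further down in this diff):

import numpy as np

def cosine_similarity(v1, v2) -> float:
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))

cached_embedding = np.array([0.12, 0.87, 0.33, 0.41])  # hypothetical cached prompt embedding
query_embedding = np.array([0.11, 0.85, 0.35, 0.40])   # hypothetical new prompt embedding

similarity_threshold = 0.95  # same value as in embedding_cache_config above
if cosine_similarity(query_embedding, cached_embedding) >= similarity_threshold:
    print("reuse the cached LLM response")
else:
    print("call the LLM")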
@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
 
     def __post_init__(self):
+        # Initialize lock only for file operations
+        self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         self.cosine_better_than_threshold = config.get(
@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         logger.info(
-            f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
+            f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
         )
         results = self._client.query(
             query=embedding,
@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error deleting relations for {entity_name}: {e}")
 
     async def index_done_callback(self):
+        # Protect file write operation
+        async with self._save_lock:
             self._client.save()
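The new _save_lock serializes NanoVectorDB file writes when several indexing pipelines finish at the same time; without it, concurrent index_done_callback calls could interleave writes to the same JSON file. A minimal standalone sketch of the pattern (FakeStorage and the sleep stand in for the real storage class and its save call):

import asyncio

class FakeStorage:
    def __init__(self):
        self._save_lock = asyncio.Lock()  # created once, like in __post_init__ above

    async def index_done_callback(self):
        # Only one task runs the file write at a time; other tasks wait here.
        async with self._save_lock:
            await asyncio.sleep(0.1)  # placeholder for self._client.save()
            print("saved without interleaving")

async def main():
    storage = FakeStorage()
    await asyncio.gather(storage.index_done_callback(), storage.index_done_callback())

asyncio.run(main())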
@@ -231,7 +231,7 @@ class LightRAG:
 
         self.llm_response_cache = self.key_string_value_json_storage_cls(
             namespace="llm_response_cache",
-            embedding_func=None,
+            embedding_func=self.embedding_func,
         )
 
         ####
@@ -275,7 +275,7 @@ class LightRAG:
         else:
             hashing_kv = self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             )
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
@@ -916,7 +916,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
                 prompt=prompt,
             )
@@ -933,7 +933,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":
@@ -952,7 +952,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:
@@ -993,7 +993,7 @@ class LightRAG:
             or self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
         )
 
@@ -1024,7 +1024,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "naive":
@@ -1040,7 +1040,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":
@@ -1059,7 +1059,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:
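All of the hunks above replace embedding_func=None with the real embedding function on the llm_response_cache storage. The reason is visible later in this diff: handle_cache now calls hashing_kv.embedding_func([prompt]) when the embedding cache is enabled, so a None there would fail at query time. A toy sketch of that dependency (MiniKV and fake_embed are illustrative stand-ins, not the project's storage classes):

import asyncio
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class MiniKV:
    namespace: str
    embedding_func: Optional[Callable]  # was None for the cache storage before this commit

async def lookup(hashing_kv: MiniKV, prompt: str):
    if hashing_kv.embedding_func is None:
        raise TypeError("embedding cache enabled but the cache storage has no embedding_func")
    # mirrors handle_cache: embed the prompt, then compare against cached embeddings
    return await hashing_kv.embedding_func([prompt])

async def fake_embed(texts):
    return [[0.1, 0.2, 0.3] for _ in texts]

print(asyncio.run(lookup(MiniKV("llm_response_cache", fake_embed), "what is LightRAG?")))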
@@ -352,16 +352,6 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            need_to_restore = False
-            if (
-                global_config["embedding_cache_config"]
-                and global_config["embedding_cache_config"]["enabled"]
-            ):
-                new_config = global_config.copy()
-                new_config["embedding_cache_config"] = None
-                new_config["enable_llm_cache"] = True
-                llm_response_cache.global_config = new_config
-                need_to_restore = True
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text
@@ -370,10 +360,13 @@ async def extract_entities(
 
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
+                cache_type="extract",
+                force_llm_cache=True,
             )
-            if need_to_restore:
-                llm_response_cache.global_config = global_config
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
                 statistic_data["llm_cache"] += 1
@@ -387,7 +380,12 @@ async def extract_entities(
             res: str = await use_llm_func(input_text)
             await save_to_cache(
                 llm_response_cache,
-                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+                CacheData(
+                    args_hash=arg_hash,
+                    content=res,
+                    prompt=_prompt,
+                    cache_type="extract",
+                ),
             )
             return res
 
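Together, the two hunks above give entity extraction its own cache lane: the prompt (plus any history) is hashed, looked up with force_llm_cache=True so the lookup works even when enable_llm_cache is off, and the LLM result is written back under cache_type="extract" instead of temporarily mutating global_config as the removed need_to_restore code did. A condensed sketch of that flow with stand-in helpers (fake_llm and the dict-backed cache are illustrative, not the project's storage):

import asyncio
from hashlib import md5

_cache: dict = {}  # stand-in for the "default" bucket of llm_response_cache

def compute_args_hash(*args) -> str:
    return md5("".join(str(a) for a in args).encode()).hexdigest()

async def fake_llm(prompt: str) -> str:
    return f"entities extracted from: {prompt[:24]}..."

async def extract_with_cache(prompt: str) -> str:
    arg_hash = compute_args_hash(prompt)
    entry = _cache.get(arg_hash)  # handle_cache(..., cache_type="extract", force_llm_cache=True)
    if entry is not None:
        return entry["return"]
    res = await fake_llm(prompt)
    _cache[arg_hash] = {"return": res, "cache_type": "extract"}  # save_to_cache(CacheData(...))
    return res

print(asyncio.run(extract_with_cache("Chunk text about Alice and Bob.")))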
@@ -740,7 +738,7 @@ async def extract_keywords_only(
     # 6. Parse out JSON from the LLM response
     match = re.search(r"\{.*\}", result, re.DOTALL)
     if not match:
-        logger.error("No JSON-like structure found in the result.")
+        logger.error("No JSON-like structure found in the LLM response.")
         return [], []
     try:
         keywords_data = json.loads(match.group(0))
@@ -752,7 +750,11 @@ async def extract_keywords_only(
     ll_keywords = keywords_data.get("low_level_keywords", [])
 
     # 7. Cache only the processed keywords with cache type
-    cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
+    if hl_keywords or ll_keywords:
+        cache_data = {
+            "high_level_keywords": hl_keywords,
+            "low_level_keywords": ll_keywords,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
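The new guard keeps empty keyword extractions out of the cache, so a response the model failed to parse is retried on the next query instead of being served from cache indefinitely. A tiny illustration with a plain dict in place of the real storage:

def cache_keywords(cache: dict, query_hash: str, hl_keywords: list, ll_keywords: list) -> None:
    # Mirrors the new guard: skip caching when both keyword lists came back empty.
    if hl_keywords or ll_keywords:
        cache[query_hash] = {
            "high_level_keywords": hl_keywords,
            "low_level_keywords": ll_keywords,
        }

cache: dict = {}
cache_keywords(cache, "abc123", [], [])                       # nothing stored, will re-extract later
cache_keywords(cache, "abc123", ["privacy"], ["GDPR fines"])  # stored for reuse
print(cache)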
@@ -290,9 +290,8 @@ PROMPTS[
 Question 1: {original_prompt}
 Question 2: {cached_prompt}
 
-Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
-1. Whether these two questions are semantically similar
-2. Whether the answer to Question 2 can be used to answer Question 1
+Please evaluate whether these two questions are semantically similar, and whether the answer to Question 2 can be used to answer Question 1, provide a similarity score between 0 and 1 directly.
+
 Similarity score criteria:
 0: Completely unrelated or answer cannot be reused, including but not limited to:
    - The questions have different topics
@@ -58,16 +58,9 @@ class EmbeddingFunc:
     embedding_dim: int
     max_token_size: int
     func: callable
-    concurrent_limit: int = 16
-
-    def __post_init__(self):
-        if self.concurrent_limit != 0:
-            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
-        else:
-            self._semaphore = UnlimitedSemaphore()
+    # concurrent_limit: int = 16
 
     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        async with self._semaphore:
         return await self.func(*args, **kwargs)
 
 
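After this hunk EmbeddingFunc no longer throttles itself: __call__ simply awaits the wrapped coroutine, and the cap on concurrent embedding calls is expected to come from the limit_async_func_call decorator rewritten in the next hunk. A reduced sketch of the class as it now behaves (field values and the fake embedding call are illustrative):

import asyncio
from dataclasses import dataclass
from typing import Callable

import numpy as np

@dataclass
class EmbeddingFuncSketch:
    embedding_dim: int
    max_token_size: int
    func: Callable

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        # no semaphore here any more; concurrency is limited by the caller's wrapper
        return await self.func(*args, **kwargs)

async def fake_embed(texts):
    await asyncio.sleep(0)  # stand-in for an HTTP call to an embedding service
    return np.zeros((len(texts), 8))

emb = EmbeddingFuncSketch(embedding_dim=8, max_token_size=8192, func=fake_embed)
print(asyncio.run(emb(["hello"])).shape)  # (1, 8)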
@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
-        cache_type: Type of cache (e.g., 'keywords', 'query')
+        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
     Returns:
         str: Hash string
     """
@@ -131,21 +124,16 @@ def compute_mdhash_id(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()
 
 
-def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
-    """Add restriction of maximum async calling times for a async func"""
+def limit_async_func_call(max_size: int):
+    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
 
     def final_decro(func):
-        """Not using async.Semaphore to aovid use nest-asyncio"""
-        __current_size = 0
+        sem = asyncio.Semaphore(max_size)
 
         @wraps(func)
         async def wait_func(*args, **kwargs):
-            nonlocal __current_size
-            while __current_size >= max_size:
-                await asyncio.sleep(waitting_time)
-            __current_size += 1
+            async with sem:
                 result = await func(*args, **kwargs)
-            __current_size -= 1
                 return result
 
         return wait_func
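The decorator drops the polling loop on __current_size (sleep, re-check, then increment) in favor of a real asyncio.Semaphore, so excess callers queue on the semaphore instead of busy-waiting and the limit cannot be exceeded. A runnable sketch of the new shape plus a quick check that no more than max_size calls overlap (the peak-tracking globals exist only for the demo):

import asyncio
from functools import wraps

def limit_async_func_call(max_size: int):
    """Cap the number of concurrently running calls of an async function."""
    def final_decro(func):
        sem = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wait_func(*args, **kwargs):
            async with sem:  # waiters suspend here, no sleep/poll loop
                return await func(*args, **kwargs)

        return wait_func
    return final_decro

# Demo: 20 calls, at most 4 in flight at once.
running = peak = 0

@limit_async_func_call(4)
async def work(i: int) -> int:
    global running, peak
    running += 1
    peak = max(peak, running)
    await asyncio.sleep(0.01)
    running -= 1
    return i

async def main():
    await asyncio.gather(*(work(i) for i in range(20)))
    print("peak concurrency:", peak)  # never exceeds 4

asyncio.run(main())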
@@ -380,6 +368,9 @@ async def get_best_cached_response(
     original_prompt=None,
     cache_type=None,
 ) -> Union[str, None]:
+    logger.debug(
+        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
+    )
     mode_cache = await hashing_kv.get_by_id(mode)
     if not mode_cache:
         return None
@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
 
-def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
+def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
     """Quantize embedding to specified bits"""
+    # Convert list to numpy array if needed
+    if isinstance(embedding, list):
+        embedding = np.array(embedding)
+
     # Calculate min/max values for reconstruction
     min_val = embedding.min()
     max_val = embedding.max()
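quantize_embedding now also accepts a plain Python list, converting it to a numpy array before computing the min/max used for 8-bit quantization. A self-contained round-trip sketch consistent with the dequantize formula shown in the next hunk (both function bodies here are reconstructions for illustration, not copies of the file):

import numpy as np

def quantize_embedding(embedding, bits: int = 8) -> tuple:
    # Accept list or ndarray, as in the updated signature.
    if isinstance(embedding, list):
        embedding = np.array(embedding)
    min_val, max_val = embedding.min(), embedding.max()
    scale = (max_val - min_val) / (2**bits - 1)
    quantized = np.round((embedding - min_val) / scale).astype(np.uint8)
    return quantized, min_val, max_val

def dequantize_embedding(quantized, min_val, max_val, bits: int = 8) -> np.ndarray:
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)

vec = [0.12, -0.5, 0.98, 0.0]  # list input now works
q, min_v, max_v = quantize_embedding(vec)
print(np.max(np.abs(dequantize_embedding(q, min_v, max_v) - np.array(vec))) < 0.01)  # small error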
@@ -491,21 +486,21 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)
 
 
-async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
+async def handle_cache(
+    hashing_kv,
+    args_hash,
+    prompt,
+    mode="default",
+    cache_type=None,
+    force_llm_cache=False,
+):
     """Generic cache handling function"""
-    if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
-        return None, None, None, None
-
-    # For default mode, only use simple cache matching
-    if mode == "default":
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
+    if hashing_kv is None or not (
+        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
+    ):
+        return None, None, None, None
 
     if mode != "default":
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
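The rewritten entry check means the cache is consulted whenever the caller passes force_llm_cache=True (as entity extraction now does) or the global enable_llm_cache flag is on, and the embedding-similarity lookup is attempted only for non-default modes. A condensed sketch of just that gating (plain values in place of the real hashing_kv and config):

def should_use_cache(hashing_kv, enable_llm_cache: bool, force_llm_cache: bool) -> bool:
    # mirrors: if hashing_kv is None or not (force_llm_cache or enable_llm_cache): return early
    return hashing_kv is not None and (force_llm_cache or enable_llm_cache)

def try_embedding_similarity(mode: str, embedding_cache_enabled: bool) -> bool:
    # only non-default modes go through the embedding cache; "default" uses exact hash matching
    return mode != "default" and embedding_cache_enabled

print(should_use_cache({}, enable_llm_cache=False, force_llm_cache=True))  # True: extraction path
print(try_embedding_similarity("default", True))                           # False: exact-hash only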
@@ -517,10 +512,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
         quantized = min_val = max_val = None
         if is_embedding_cache_enabled:
             # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
-
-            current_embedding = await embedding_model_func([prompt])
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
             best_cached_response = await get_best_cached_response(
                 hashing_kv,
@@ -529,12 +522,15 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
                 mode=mode,
                 use_llm_check=use_llm_check,
                 llm_func=llm_model_func if use_llm_check else None,
-                original_prompt=prompt if use_llm_check else None,
+                original_prompt=prompt,
                 cache_type=cache_type,
             )
             if best_cached_response is not None:
                 return best_cached_response, None, None, None
+            else:
+                return None, quantized, min_val, max_val
 
+    # For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
     # Use regular cache
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -543,7 +539,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
     if args_hash in mode_cache:
         return mode_cache[args_hash]["return"], None, None, None
 
-    return None, quantized, min_val, max_val
+    return None, None, None, None
 
 
 @dataclass
@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
 
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
+        "cache_type": cache_data.cache_type,
         "embedding": cache_data.quantized.tobytes().hex()
         if cache_data.quantized is not None
         else None,
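With cache_type stored next to each answer, a record written by save_to_cache looks roughly like the dict below; only the fields visible in this hunk are shown, and the values are invented. The embedding field stays None when no quantized embedding was passed in:

entry = {
    "return": "...LLM answer or extraction result...",
    "cache_type": "extract",   # new field: e.g. 'extract', 'keywords', 'query'
    "embedding": None,         # hex-encoded quantized embedding when one was computed
}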