Merge pull request #693 from danielaskdd/fix-concurrent-problem

Fixed concurrency problems for document indexing and user queries
zrguo
2025-02-02 18:26:42 +08:00
committed by GitHub
6 changed files with 141 additions and 126 deletions
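Taken together, the changes below rely on three concurrency-control primitives: a threading.Lock around the shared scan-progress dict, an asyncio.Lock around the vector-database file save, and an asyncio.Semaphore capping concurrent async calls. A minimal, self-contained sketch of the three patterns (names here are illustrative, not taken from the repository):

import asyncio
import threading

# 1. Thread lock: guards the shared scan-progress dict updated by workers.
scan_progress = {"is_scanning": False, "indexed_count": 0}
progress_lock = threading.Lock()

def mark_indexed(count: int) -> None:
    with progress_lock:
        scan_progress["indexed_count"] = count

# 2. Asyncio lock: serializes a non-reentrant file write across coroutines.
save_lock = asyncio.Lock()

async def save_index(client) -> None:
    async with save_lock:
        client.save()  # only one coroutine may write the file at a time

# 3. Asyncio semaphore: bounds the number of in-flight async calls.
semaphore = asyncio.Semaphore(16)

async def limited_call(func, *args, **kwargs):
    async with semaphore:
        return await func(*args, **kwargs)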

View File

@@ -13,18 +13,6 @@ from fastapi import (
from typing import Dict
import threading
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = threading.Lock()
import json
import os
@@ -34,7 +22,7 @@ import logging
import argparse
import time
import re
from typing import List, Dict, Any, Optional, Union
from typing import List, Any, Optional, Union
from lightrag import LightRAG, QueryParam
from lightrag.api import __api_version__
@@ -57,8 +45,21 @@ import pipmaster as pm
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = threading.Lock()
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
@@ -918,6 +919,12 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce LLM cost
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
)
else:
rag = LightRAG(
@@ -941,6 +948,12 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce LLM cost
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
)
async def index_file(file_path: Union[str, Path]) -> None:
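
Both LightRAG constructions above add the same two cache options; as a quick reference, a sketch of just those values (the remaining constructor arguments such as working_dir, llm_model_func and embedding_func are elided and assumed to match the server setup above):

# Cache options added to both LightRAG(...) calls above.
enable_llm_cache_for_entity_extract = False  # set to True while debugging to reduce LLM cost
embedding_cache_config = {
    "enabled": True,               # look up cached answers by query embedding
    "similarity_threshold": 0.95,  # minimum similarity needed to reuse a cached answer
    "use_llm_check": False,        # skip the extra LLM verification of a cache hit
}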

View File

@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
# Initialize lock only for file operations
self._save_lock = asyncio.Lock()
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
embedding = embedding[0]
logger.info(
f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
)
results = self._client.query(
query=embedding,
@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self):
# Protect file write operation
async with self._save_lock:
self._client.save()
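
The save lock is the whole fix for this file: the client's save() is plain file I/O, so two indexing pipelines finishing at once could interleave writes. A stripped-down sketch of the pattern (the class and client here are illustrative stand-ins, not the repository's code):

import asyncio

class VectorStorageSketch:
    """Illustrative stand-in for NanoVectorDBStorage's save path."""

    def __init__(self, client):
        self._client = client
        # Lock only guards the file write; queries and upserts stay lock-free.
        self._save_lock = asyncio.Lock()

    async def index_done_callback(self):
        # Without the lock, concurrent indexing jobs could call save() at the
        # same time and corrupt the on-disk storage file.
        async with self._save_lock:
            self._client.save()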

View File

@@ -231,7 +231,7 @@ class LightRAG:
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
embedding_func=self.embedding_func,
)
####
@@ -275,7 +275,7 @@ class LightRAG:
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
embedding_func=self.embedding_func,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
@@ -916,7 +916,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
prompt=prompt,
)
@@ -933,7 +933,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
@@ -952,7 +952,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
else:
@@ -993,7 +993,7 @@ class LightRAG:
or self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
@@ -1024,7 +1024,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
@@ -1040,7 +1040,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
@@ -1059,7 +1059,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
else:
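
Every llm_response_cache construction above switches from embedding_func=None to the real embedding function. The reason shows up in the utils.py hunks further down: handle_cache now embeds the prompt through the storage object itself instead of reaching into global_config. Roughly:

# Rough sketch of why the cache KV now needs an embedding function:
# handle_cache() calls the storage's own embedding_func to embed the prompt,
# so a cache built with embedding_func=None cannot do similarity lookups.
async def embed_prompt_via_cache(hashing_kv, prompt: str):
    current_embedding = await hashing_kv.embedding_func([prompt])
    return current_embedding[0]  # fed to quantize_embedding / get_best_cached_response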

View File

@@ -352,16 +352,6 @@ async def extract_entities(
input_text: str, history_messages: list[dict[str, str]] = None
) -> str:
if enable_llm_cache_for_entity_extract and llm_response_cache:
need_to_restore = False
if (
global_config["embedding_cache_config"]
and global_config["embedding_cache_config"]["enabled"]
):
new_config = global_config.copy()
new_config["embedding_cache_config"] = None
new_config["enable_llm_cache"] = True
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
@@ -370,10 +360,13 @@ async def extract_entities(
arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
llm_response_cache,
arg_hash,
_prompt,
"default",
cache_type="extract",
force_llm_cache=True,
)
if need_to_restore:
llm_response_cache.global_config = global_config
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
@@ -387,7 +380,12 @@ async def extract_entities(
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
CacheData(
args_hash=arg_hash,
content=res,
prompt=_prompt,
cache_type="extract",
),
)
return res
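
Reassembled from the hunk above, the extraction cache path now reads roughly as below; compute_args_hash, handle_cache, save_to_cache and CacheData are the helpers changed in utils.py later in this diff (the import path is assumed):

# Sketch of the extraction cache flow after this change: probe the cache with
# cache_type="extract" and force_llm_cache=True, call the LLM only on a miss,
# then store the result under the same cache type.
from lightrag.utils import CacheData, compute_args_hash, handle_cache, save_to_cache  # assumed path

async def cached_extract(llm_response_cache, use_llm_func, _prompt: str, input_text: str) -> str:
    arg_hash = compute_args_hash(_prompt)
    cached_return, _1, _2, _3 = await handle_cache(
        llm_response_cache,
        arg_hash,
        _prompt,
        "default",
        cache_type="extract",
        force_llm_cache=True,
    )
    if cached_return:
        return cached_return
    res = await use_llm_func(input_text)
    await save_to_cache(
        llm_response_cache,
        CacheData(args_hash=arg_hash, content=res, prompt=_prompt, cache_type="extract"),
    )
    return res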
@@ -740,7 +738,7 @@ async def extract_keywords_only(
# 6. Parse out JSON from the LLM response
match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.error("No JSON-like structure found in the result.")
logger.error("No JSON-like structure found in the LLM respond.")
return [], []
try:
keywords_data = json.loads(match.group(0))
@@ -752,7 +750,11 @@ async def extract_keywords_only(
ll_keywords = keywords_data.get("low_level_keywords", [])
# 7. Cache only the processed keywords with cache type
cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
if hl_keywords or ll_keywords:
cache_data = {
"high_level_keywords": hl_keywords,
"low_level_keywords": ll_keywords,
}
await save_to_cache(
hashing_kv,
CacheData(

View File

@@ -290,9 +290,8 @@ PROMPTS[
Question 1: {original_prompt}
Question 2: {cached_prompt}
Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
1. Whether these two questions are semantically similar
2. Whether the answer to Question 2 can be used to answer Question 1
Please evaluate whether these two questions are semantically similar and whether the answer to Question 2 can be used to answer Question 1, then provide a similarity score between 0 and 1 directly.
Similarity score criteria:
0: Completely unrelated or answer cannot be reused, including but not limited to:
- The questions have different topics

View File

@@ -58,16 +58,9 @@ class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
concurrent_limit: int = 16
def __post_init__(self):
if self.concurrent_limit != 0:
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
else:
self._semaphore = UnlimitedSemaphore()
# concurrent_limit: int = 16
async def __call__(self, *args, **kwargs) -> np.ndarray:
async with self._semaphore:
return await self.func(*args, **kwargs)
@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query')
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""
@@ -131,21 +124,16 @@ def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
def limit_async_func_call(max_size: int):
"""Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
def final_decro(func):
"""Not using async.Semaphore to aovid use nest-asyncio"""
__current_size = 0
sem = asyncio.Semaphore(max_size)
@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
async with sem:
result = await func(*args, **kwargs)
__current_size -= 1
return result
return wait_func
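
The counter-and-sleep polling loop is replaced by a plain asyncio.Semaphore. A self-contained sketch of the new decorator plus an illustrative usage (fake_model and the limit of 4 are made up for the example):

import asyncio
from functools import wraps

def limit_async_func_call(max_size: int):
    """Add restriction of maximum concurrent async calls using asyncio.Semaphore."""
    def final_decro(func):
        sem = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wait_func(*args, **kwargs):
            async with sem:  # blocks here once max_size calls are in flight
                return await func(*args, **kwargs)

        return wait_func

    return final_decro

# Illustrative usage: at most 4 fake model calls run concurrently.
@limit_async_func_call(4)
async def fake_model(i: int) -> int:
    await asyncio.sleep(0.01)
    return i

async def main() -> list:
    return await asyncio.gather(*(fake_model(i) for i in range(16)))

# asyncio.run(main())  # runs 16 calls with at most 4 in flight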
@@ -380,6 +368,9 @@ async def get_best_cached_response(
original_prompt=None,
cache_type=None,
) -> Union[str, None]:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None
@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
return dot_product / (norm1 * norm2)
def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
"""Quantize embedding to specified bits"""
# Convert list to numpy array if needed
if isinstance(embedding, list):
embedding = np.array(embedding)
# Calculate min/max values for reconstruction
min_val = embedding.min()
max_val = embedding.max()
@@ -491,21 +486,21 @@ def dequantize_embedding(
return (quantized * scale + min_val).astype(np.float32)
async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
async def handle_cache(
hashing_kv,
args_hash,
prompt,
mode="default",
cache_type=None,
force_llm_cache=False,
):
"""Generic cache handling function"""
if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
return None, None, None, None
# For default mode, only use simple cache matching
if mode == "default":
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
if hashing_kv is None or not (
force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
):
return None, None, None, None
if mode != "default":
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
@@ -517,10 +512,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await hashing_kv.embedding_func([prompt])
llm_model_func = hashing_kv.global_config.get("llm_model_func")
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
@@ -529,12 +522,15 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt if use_llm_check else None,
original_prompt=prompt,
cache_type=cache_type,
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
return None, quantized, min_val, max_val
# For default mode (extract_entities or naive query), or when embedding cache is disabled
# Use regular cache
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -543,7 +539,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, quantized, min_val, max_val
return None, None, None, None
@dataclass
@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,
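
For reference, the quantized embedding stored in the cache entry round-trips as shown below; the quantization formula is a simplified sketch and may differ in detail from the repository's quantize_embedding/dequantize_embedding:

import numpy as np

# Simplified sketch of the embedding round trip through a cache entry:
# 8-bit quantization, hex serialization (as in save_to_cache above), and
# approximate reconstruction.
def quantize_sketch(embedding, bits: int = 8):
    if isinstance(embedding, list):  # the change above: accept plain lists too
        embedding = np.array(embedding)
    min_val, max_val = embedding.min(), embedding.max()
    scale = (max_val - min_val) / (2**bits - 1)
    quantized = np.round((embedding - min_val) / scale).astype(np.uint8)
    return quantized, min_val, max_val

codes, lo, hi = quantize_sketch([0.1, -0.3, 0.7, 0.0])
stored_hex = codes.tobytes().hex()                    # what the cache entry stores
restored = np.frombuffer(bytes.fromhex(stored_hex), dtype=np.uint8)
scale = (hi - lo) / 255
approx = (restored * scale + lo).astype(np.float32)   # close to the original values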