diff --git a/README.md b/README.md index dd215b04..950c5c5a 100644 --- a/README.md +++ b/README.md @@ -455,9 +455,38 @@ For production level scenarios you will most likely want to leverage an enterpri * If you prefer docker, please start with this image if you are a beginner to avoid hiccups (DO read the overview): https://hub.docker.com/r/shangor/postgres-for-rag * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py) * Create index for AGE example: (Change below `dickens` to your graph name if necessary) - ``` + ```sql + LOAD 'age'; SET search_path = ag_catalog, "$user", public; - CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"')); + CREATE INDEX CONCURRENTLY entity_p_idx ON dickens."Entity" (id); + CREATE INDEX CONCURRENTLY vertex_p_idx ON dickens."_ag_label_vertex" (id); + CREATE INDEX CONCURRENTLY directed_p_idx ON dickens."DIRECTED" (id); + CREATE INDEX CONCURRENTLY directed_eid_idx ON dickens."DIRECTED" (end_id); + CREATE INDEX CONCURRENTLY directed_sid_idx ON dickens."DIRECTED" (start_id); + CREATE INDEX CONCURRENTLY directed_seid_idx ON dickens."DIRECTED" (start_id,end_id); + CREATE INDEX CONCURRENTLY edge_p_idx ON dickens."_ag_label_edge" (id); + CREATE INDEX CONCURRENTLY edge_sid_idx ON dickens."_ag_label_edge" (start_id); + CREATE INDEX CONCURRENTLY edge_eid_idx ON dickens."_ag_label_edge" (end_id); + CREATE INDEX CONCURRENTLY edge_seid_idx ON dickens."_ag_label_edge" (start_id,end_id); + CREATE INDEX CONCURRENTLY vertex_idx_node_id ON dickens."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype)); + CREATE INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype)); + CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" USING gin (properties); + ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx; + + -- drop if necessary + DROP INDEX entity_p_idx; + DROP INDEX vertex_p_idx; + DROP INDEX directed_p_idx; + DROP INDEX directed_eid_idx; + DROP INDEX directed_sid_idx; + DROP INDEX directed_seid_idx; + DROP INDEX edge_p_idx; + DROP INDEX edge_sid_idx; + DROP INDEX edge_eid_idx; + DROP INDEX edge_seid_idx; + DROP INDEX vertex_idx_node_id; + DROP INDEX entity_idx_node_id; + DROP INDEX entity_node_id_gin_idx; ``` * Known issue of the Apache AGE: The released versions got below issue: > You might find that the properties of the nodes/edges are empty.
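After creating the indexes above, it is worth confirming that they actually exist in the graph's schema (Apache AGE keeps each graph in a schema named after it, so `dickens` here). A minimal sketch of such a check, assuming asyncpg and a local default DSN rather than whatever driver and connection settings your deployment really uses:

```python
import asyncio

import asyncpg  # assumption: any PostgreSQL driver works; adjust the DSN to your setup


async def list_age_indexes(dsn: str, graph: str = "dickens") -> None:
    # The AGE graph name doubles as the schema name, so pg_indexes can be filtered on it.
    conn = await asyncpg.connect(dsn)
    try:
        rows = await conn.fetch(
            "SELECT tablename, indexname FROM pg_indexes WHERE schemaname = $1 ORDER BY tablename",
            graph,
        )
        for row in rows:
            print(row["tablename"], "->", row["indexname"])
    finally:
        await conn.close()


asyncio.run(list_age_indexes("postgresql://postgres:postgres@localhost:5432/postgres"))
```

Note that an interrupted `CREATE INDEX CONCURRENTLY` leaves the index behind in an invalid state; if that happens, drop it and recreate it.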
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index f7770c57..fa192f9c 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -13,18 +13,6 @@ from fastapi import ( from typing import Dict import threading -# Global progress tracker -scan_progress: Dict = { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, -} - -# Lock for thread-safe operations -progress_lock = threading.Lock() - import json import os @@ -34,7 +22,7 @@ import logging import argparse import time import re -from typing import List, Dict, Any, Optional, Union +from typing import List, Any, Optional, Union from lightrag import LightRAG, QueryParam from lightrag.api import __api_version__ @@ -57,8 +45,21 @@ import pipmaster as pm from dotenv import load_dotenv +# Load environment variables load_dotenv() +# Global progress tracker +scan_progress: Dict = { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, +} + +# Lock for thread-safe operations +progress_lock = threading.Lock() + def estimate_tokens(text: str) -> int: """Estimate the number of tokens in text @@ -918,6 +919,12 @@ def create_app(args): vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, + enable_llm_cache_for_entity_extract=False, # set to True during debugging to reduce LLM cost + embedding_cache_config={ + "enabled": True, + "similarity_threshold": 0.95, + "use_llm_check": False, + }, ) else: rag = LightRAG( @@ -941,6 +948,12 @@ def create_app(args): vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, + enable_llm_cache_for_entity_extract=False, # set to True during debugging to reduce LLM cost + embedding_cache_config={ + "enabled": True, + "similarity_threshold": 0.95, + "use_llm_check": False, + }, ) async def index_file(file_path: Union[str, Path]) -> None: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index ed272fee..6e8873fc 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage): cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): + # Initialize lock only for file operations + self._save_lock = asyncio.Lock() # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) self.cosine_better_than_threshold = config.get( @@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func([query]) embedding = embedding[0] logger.info( - f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}" + f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}" ) results = self._client.query( query=embedding, @@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self): - self._client.save() + # Protect file write operation + async with self._save_lock: + self._client.save() diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index b315abca..af62c522 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -30,6 +30,7 @@ from ..base import ( DocStatus, DocProcessingStatus, BaseGraphStorage, + T, ) if sys.platform.startswith("win"): @@
-442,6 +443,22 @@ class PGDocStatusStorage(DocStatusStorage): existed = set([element["id"] for element in result]) return set(data) - existed + async def get_by_id(self, id: str) -> Union[T, None]: + sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" + params = {"workspace": self.db.workspace, "id": id} + result = await self.db.query(sql, params, True) + if result is None: + return None + else: + return DocProcessingStatus( + content_length=result[0]["content_length"], + content_summary=result[0]["content_summary"], + status=result[0]["status"], + chunks_count=result[0]["chunks_count"], + created_at=result[0]["created_at"], + updated_at=result[0]["updated_at"], + ) + async def get_status_counts(self) -> Dict[str, int]: """Get counts of documents in each status""" sql = """SELECT status as "status", COUNT(1) as "count" @@ -884,9 +901,9 @@ class PGGraphStorage(BaseGraphStorage): query = """SELECT * FROM cypher('%s', $$ MATCH (n:Entity {node_id: "%s"}) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected - $$) AS (n agtype, r agtype, connected agtype)""" % ( + OPTIONAL MATCH (n)-[]-(connected) + RETURN n, connected + $$) AS (n agtype, connected agtype)""" % ( self.graph_name, label, ) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 22db6994..f83c9e38 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -231,7 +231,7 @@ class LightRAG: self.llm_response_cache = self.key_string_value_json_storage_cls( namespace="llm_response_cache", - embedding_func=None, + embedding_func=self.embedding_func, ) #### @@ -275,7 +275,7 @@ class LightRAG: else: hashing_kv = self.key_string_value_json_storage_cls( namespace="llm_response_cache", - embedding_func=None, + embedding_func=self.embedding_func, ) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( @@ -373,7 +373,7 @@ class LightRAG: doc_id for doc_id in new_docs.keys() if (current_doc := await self.doc_status.get_by_id(doc_id)) is None - or current_doc["status"] == DocStatus.FAILED + or current_doc.status == DocStatus.FAILED } new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} @@ -916,7 +916,7 @@ class LightRAG: else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), - embedding_func=None, + embedding_func=self.embedding_func, ), prompt=prompt, ) @@ -933,7 +933,7 @@ class LightRAG: else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), - embedding_func=None, + embedding_func=self.embedding_func, ), ) elif param.mode == "mix": @@ -952,7 +952,7 @@ class LightRAG: else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), - embedding_func=None, + embedding_func=self.embedding_func, ), ) else: @@ -993,7 +993,7 @@ class LightRAG: or self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), - embedding_func=None, + embedding_func=self.embedding_func, ), ) @@ -1024,7 +1024,7 @@ class LightRAG: else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), - embedding_func=None, + embedding_func=self.embedding_func, ), ) elif param.mode == "naive": @@ -1040,7 +1040,7 @@ class LightRAG: else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), - embedding_func=None, + embedding_func=self.embedding_func, ), ) elif param.mode == "mix": @@ -1059,7 +1059,7 @@ class LightRAG: else
self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), - embedding_func=None, + embedding_func=self.embedding_func, ), ) else: diff --git a/lightrag/operate.py b/lightrag/operate.py index 6a1763c7..c8c50f61 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -352,16 +352,6 @@ async def extract_entities( input_text: str, history_messages: list[dict[str, str]] = None ) -> str: if enable_llm_cache_for_entity_extract and llm_response_cache: - need_to_restore = False - if ( - global_config["embedding_cache_config"] - and global_config["embedding_cache_config"]["enabled"] - ): - new_config = global_config.copy() - new_config["embedding_cache_config"] = None - new_config["enable_llm_cache"] = True - llm_response_cache.global_config = new_config - need_to_restore = True if history_messages: history = json.dumps(history_messages, ensure_ascii=False) _prompt = history + "\n" + input_text @@ -370,10 +360,13 @@ async def extract_entities( arg_hash = compute_args_hash(_prompt) cached_return, _1, _2, _3 = await handle_cache( - llm_response_cache, arg_hash, _prompt, "default", cache_type="default" + llm_response_cache, + arg_hash, + _prompt, + "default", + cache_type="extract", + force_llm_cache=True, ) - if need_to_restore: - llm_response_cache.global_config = global_config if cached_return: logger.debug(f"Found cache for {arg_hash}") statistic_data["llm_cache"] += 1 @@ -387,7 +380,12 @@ async def extract_entities( res: str = await use_llm_func(input_text) await save_to_cache( llm_response_cache, - CacheData(args_hash=arg_hash, content=res, prompt=_prompt), + CacheData( + args_hash=arg_hash, + content=res, + prompt=_prompt, + cache_type="extract", + ), ) return res @@ -740,7 +738,7 @@ async def extract_keywords_only( # 6. Parse out JSON from the LLM response match = re.search(r"\{.*\}", result, re.DOTALL) if not match: - logger.error("No JSON-like structure found in the result.") + logger.error("No JSON-like structure found in the LLM response.") return [], [] try: keywords_data = json.loads(match.group(0)) @@ -752,20 +750,24 @@ async def extract_keywords_only( ll_keywords = keywords_data.get("low_level_keywords", []) # 7. Cache only the processed keywords with cache type - cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords} - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=json.dumps(cache_data), - prompt=text, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=param.mode, - cache_type="keywords", - ), - ) + if hl_keywords or ll_keywords: + cache_data = { + "high_level_keywords": hl_keywords, + "low_level_keywords": ll_keywords, + } + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=json.dumps(cache_data), + prompt=text, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=param.mode, + cache_type="keywords", + ), + ) return hl_keywords, ll_keywords diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 913f8eef..160663d9 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -290,9 +290,8 @@ PROMPTS[ Question 1: {original_prompt} Question 2: {cached_prompt} -Please evaluate the following two points and provide a similarity score between 0 and 1 directly: -1. Whether these two questions are semantically similar -2.
Whether the answer to Question 2 can be used to answer Question 1 +Please evaluate whether these two questions are semantically similar and whether the answer to Question 2 can be used to answer Question 1, then provide a similarity score between 0 and 1 directly. + Similarity score criteria: 0: Completely unrelated or answer cannot be reused, including but not limited to: - The questions have different topics diff --git a/lightrag/utils.py b/lightrag/utils.py index 1ddcde50..3a69513b 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -58,17 +58,10 @@ class EmbeddingFunc: embedding_dim: int max_token_size: int func: callable - concurrent_limit: int = 16 - - def __post_init__(self): - if self.concurrent_limit != 0: - self._semaphore = asyncio.Semaphore(self.concurrent_limit) - else: - self._semaphore = UnlimitedSemaphore() + # concurrent_limit: int = 16 async def __call__(self, *args, **kwargs) -> np.ndarray: - async with self._semaphore: - return await self.func(*args, **kwargs) + return await self.func(*args, **kwargs) def locate_json_string_body_from_string(content: str) -> Union[str, None]: @@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str: """Compute a hash for the given arguments. Args: *args: Arguments to hash - cache_type: Type of cache (e.g., 'keywords', 'query') + cache_type: Type of cache (e.g., 'keywords', 'query', 'extract') Returns: str: Hash string """ @@ -131,22 +124,17 @@ def compute_mdhash_id(content, prefix: str = ""): return prefix + md5(content.encode()).hexdigest() -def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): - """Add restriction of maximum async calling times for a async func""" +def limit_async_func_call(max_size: int): + """Limit the maximum number of concurrent async calls using asyncio.Semaphore""" def final_decro(func): - """Not using async.Semaphore to aovid use nest-asyncio""" - __current_size = 0 + sem = asyncio.Semaphore(max_size) @wraps(func) async def wait_func(*args, **kwargs): - nonlocal __current_size - while __current_size >= max_size: - await asyncio.sleep(waitting_time) - __current_size += 1 - result = await func(*args, **kwargs) - __current_size -= 1 - return result + async with sem: - result = await func(*args, **kwargs) - return result + result = await func(*args, **kwargs) + return result return wait_func @@ -380,6 +368,9 @@ async def get_best_cached_response( original_prompt=None, cache_type=None, ) -> Union[str, None]: + logger.debug( + f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}" + ) mode_cache = await hashing_kv.get_by_id(mode) if not mode_cache: return None @@ -470,8 +461,12 @@ def cosine_similarity(v1, v2): return dot_product / (norm1 * norm2) -def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple: +def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple: """Quantize embedding to specified bits""" + # Convert list to numpy array if needed + if isinstance(embedding, list): + embedding = np.array(embedding) + # Calculate min/max values for reconstruction min_val = embedding.min() max_val = embedding.max() @@ -491,59 +486,60 @@ def dequantize_embedding( return (quantized * scale + min_val).astype(np.float32) -async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None): +async def handle_cache( + hashing_kv, + args_hash, + prompt, + mode="default", + cache_type=None, + force_llm_cache=False, +): """Generic cache handling function""" - if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"): + if hashing_kv is None
or not ( + force_llm_cache or hashing_kv.global_config.get("enable_llm_cache") + ): return None, None, None, None - # For default mode, only use simple cache matching - if mode == "default": - if exists_func(hashing_kv, "get_by_mode_and_id"): - mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} - else: - mode_cache = await hashing_kv.get_by_id(mode) or {} - if args_hash in mode_cache: - return mode_cache[args_hash]["return"], None, None, None - return None, None, None, None - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", - {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}, - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - use_llm_check = embedding_cache_config.get("use_llm_check", False) - - quantized = min_val = max_val = None - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - llm_model_func = hashing_kv.global_config.get("llm_model_func") - - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - mode=mode, - use_llm_check=use_llm_check, - llm_func=llm_model_func if use_llm_check else None, - original_prompt=prompt if use_llm_check else None, - cache_type=cache_type, + if mode != "default": + # Get embedding cache configuration + embedding_cache_config = hashing_kv.global_config.get( + "embedding_cache_config", + {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}, ) + is_embedding_cache_enabled = embedding_cache_config["enabled"] + use_llm_check = embedding_cache_config.get("use_llm_check", False) - if best_cached_response is not None: - return best_cached_response, None, None, None - else: - # Use regular cache - if exists_func(hashing_kv, "get_by_mode_and_id"): - mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} - else: - mode_cache = await hashing_kv.get_by_id(mode) or {} - if args_hash in mode_cache: - return mode_cache[args_hash]["return"], None, None, None + quantized = min_val = max_val = None + if is_embedding_cache_enabled: + # Use embedding cache + current_embedding = await hashing_kv.embedding_func([prompt]) + llm_model_func = hashing_kv.global_config.get("llm_model_func") + quantized, min_val, max_val = quantize_embedding(current_embedding[0]) + best_cached_response = await get_best_cached_response( + hashing_kv, + current_embedding[0], + similarity_threshold=embedding_cache_config["similarity_threshold"], + mode=mode, + use_llm_check=use_llm_check, + llm_func=llm_model_func if use_llm_check else None, + original_prompt=prompt, + cache_type=cache_type, + ) + if best_cached_response is not None: + return best_cached_response, None, None, None + else: + return None, quantized, min_val, max_val + + # For default mode (extract_entities or naive query), or when the embedding cache is disabled + # Use regular cache + if exists_func(hashing_kv, "get_by_mode_and_id"): + mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} + else: + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + + return None, None, None, None @dataclass @@
-572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): mode_cache[cache_data.args_hash] = { "return": cache_data.content, + "cache_type": cache_data.cache_type, "embedding": cache_data.quantized.tobytes().hex() if cache_data.quantized is not None else None,
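Taken together, the `handle_cache` / `save_to_cache` changes mean that for non-default modes the query embedding is quantized to 8 bits and stored next to the cached answer, and a later request reuses that answer only when the cosine similarity against the restored vector clears `similarity_threshold` (0.95 in the server config above). A standalone NumPy sketch of that round trip, using stand-in helpers rather than LightRAG's own `quantize_embedding`/`dequantize_embedding` and an assumed 768-dimensional embedding:

```python
import numpy as np


def quantize(embedding: np.ndarray, bits: int = 8):
    # Map floats onto 2**bits integer buckets; keep min/max so the vector can be restored.
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / (max_val - min_val)
    return np.round((embedding - min_val) * scale).astype(np.uint8), min_val, max_val


def dequantize(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8):
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)


def cosine(v1: np.ndarray, v2: np.ndarray) -> float:
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))


rng = np.random.default_rng(42)
query_embedding = rng.normal(size=768).astype(np.float32)  # assumed embedding dimension

quantized, lo, hi = quantize(query_embedding)   # what gets written next to the cached answer
restored = dequantize(quantized, lo, hi)        # what a later lookup compares against

similarity = cosine(query_embedding, restored)
print(f"round-trip similarity: {similarity:.4f}")
print("cache hit:", similarity >= 0.95)         # threshold from embedding_cache_config
```

The 8-bit round trip loses very little cosine similarity, which is why the cached `quantized`/`min_val`/`max_val` triple is enough for the similarity check without storing full-precision vectors.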