Merge branch 'HKUDS:main' into main

This commit is contained in:
Saifeddine ALOUI
2025-02-03 11:24:08 +01:00
committed by GitHub
8 changed files with 193 additions and 132 deletions

View File

@@ -455,9 +455,38 @@ For production level scenarios you will most likely want to leverage an enterpri
 * If you prefer docker, please start with this image if you are a beginner to avoid hiccups (DO read the overview): https://hub.docker.com/r/shangor/postgres-for-rag
 * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
 * Create index for AGE example: (Change below `dickens` to your graph name if necessary)
-  ```
+  ```sql
+  load 'age';
   SET search_path = ag_catalog, "$user", public;
-  CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
+  CREATE INDEX CONCURRENTLY entity_p_idx ON dickens."Entity" (id);
+  CREATE INDEX CONCURRENTLY vertex_p_idx ON dickens."_ag_label_vertex" (id);
+  CREATE INDEX CONCURRENTLY directed_p_idx ON dickens."DIRECTED" (id);
+  CREATE INDEX CONCURRENTLY directed_eid_idx ON dickens."DIRECTED" (end_id);
+  CREATE INDEX CONCURRENTLY directed_sid_idx ON dickens."DIRECTED" (start_id);
+  CREATE INDEX CONCURRENTLY directed_seid_idx ON dickens."DIRECTED" (start_id,end_id);
+  CREATE INDEX CONCURRENTLY edge_p_idx ON dickens."_ag_label_edge" (id);
+  CREATE INDEX CONCURRENTLY edge_sid_idx ON dickens."_ag_label_edge" (start_id);
+  CREATE INDEX CONCURRENTLY edge_eid_idx ON dickens."_ag_label_edge" (end_id);
+  CREATE INDEX CONCURRENTLY edge_seid_idx ON dickens."_ag_label_edge" (start_id,end_id);
+  create INDEX CONCURRENTLY vertex_idx_node_id ON dickens."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
+  create INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
+  CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" using gin(properties);
+  ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx;
+
+  -- drop if necessary
+  drop INDEX entity_p_idx;
+  drop INDEX vertex_p_idx;
+  drop INDEX directed_p_idx;
+  drop INDEX directed_eid_idx;
+  drop INDEX directed_sid_idx;
+  drop INDEX directed_seid_idx;
+  drop INDEX edge_p_idx;
+  drop INDEX edge_sid_idx;
+  drop INDEX edge_eid_idx;
+  drop INDEX edge_seid_idx;
+  drop INDEX vertex_idx_node_id;
+  drop INDEX entity_idx_node_id;
+  drop INDEX entity_node_id_gin_idx;
   ```
 * Known issue of the Apache AGE: The released versions got below issue:
 > You might find that the properties of the nodes/edges are empty.
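Note that `CREATE INDEX CONCURRENTLY` cannot run inside a transaction block, so issue the statements above one at a time with autocommit enabled (the psql default). For completeness, below is a minimal sketch of preparing an AGE session from application code; the asyncpg driver, the DSN, and the sanity-check query are illustrative assumptions, not part of the LightRAG setup itself:

```python
import asyncio

import asyncpg  # assumption: any PostgreSQL driver works, asyncpg is just one choice


async def prepare_age_session(dsn: str) -> None:
    """Load AGE and set the search_path before querying the `dickens` graph."""
    conn = await asyncpg.connect(dsn)
    try:
        await conn.execute("LOAD 'age';")
        await conn.execute('SET search_path = ag_catalog, "$user", public;')
        # Sanity check: count vertices in the graph created by LightRAG
        row = await conn.fetchrow("SELECT count(*) AS n FROM dickens._ag_label_vertex;")
        print("vertices:", row["n"])
    finally:
        await conn.close()


asyncio.run(prepare_age_session("postgresql://user:password@localhost:5432/rag"))
```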

View File

@@ -13,18 +13,6 @@ from fastapi import (
 from typing import Dict
 import threading

-# Global progress tracker
-scan_progress: Dict = {
-    "is_scanning": False,
-    "current_file": "",
-    "indexed_count": 0,
-    "total_files": 0,
-    "progress": 0,
-}
-# Lock for thread-safe operations
-progress_lock = threading.Lock()

 import json
 import os
@@ -34,7 +22,7 @@ import logging
 import argparse
 import time
 import re
-from typing import List, Dict, Any, Optional, Union
+from typing import List, Any, Optional, Union

 from lightrag import LightRAG, QueryParam
 from lightrag.api import __api_version__
@@ -57,8 +45,21 @@ import pipmaster as pm
 from dotenv import load_dotenv

+# Load environment variables
 load_dotenv()
+
+# Global progress tracker
+scan_progress: Dict = {
+    "is_scanning": False,
+    "current_file": "",
+    "indexed_count": 0,
+    "total_files": 0,
+    "progress": 0,
+}
+
+# Lock for thread-safe operations
+progress_lock = threading.Lock()

 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text
@@ -918,6 +919,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce LLM fees
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )
     else:
         rag = LightRAG(
@@ -941,6 +948,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce LLM fees
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )

     async def index_file(file_path: Union[str, Path]) -> None:
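Both branches of `create_app` now pass the same two cache-related arguments. A condensed sketch of the construction follows; `working_dir`, `llm_model_func`, and `embedding_func` are placeholders assumed to be defined elsewhere, and only the two cache kwargs come verbatim from this commit:

```python
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",        # placeholder
    llm_model_func=my_llm_func,         # placeholder, defined elsewhere
    embedding_func=my_embedding_func,   # placeholder, defined elsewhere
    # Cache LLM output of the entity-extraction stage; the server keeps it off
    # and suggests turning it on while debugging to avoid repeated LLM calls.
    enable_llm_cache_for_entity_extract=False,
    # Reuse a cached answer when a new prompt embeds close enough to an old one.
    embedding_cache_config={
        "enabled": True,
        "similarity_threshold": 0.95,
        "use_llm_check": False,
    },
)
```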

View File

@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))

     def __post_init__(self):
+        # Initialize lock only for file operations
+        self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         self.cosine_better_than_threshold = config.get(
@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         logger.info(
-            f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
+            f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
         )
         results = self._client.query(
             query=embedding,
@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self): async def index_done_callback(self):
# Protect file write operation
async with self._save_lock:
self._client.save() self._client.save()
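The same pattern in isolation: an `asyncio.Lock` serializes the snapshot write so that two overlapping `index_done_callback()` calls cannot interleave file writes. A self-contained toy version follows; `FakeClient` is a stand-in and not the NanoVectorDB API:

```python
import asyncio


class FakeClient:
    """Stand-in for the NanoVectorDB client; only save() matters here."""

    def save(self) -> None:
        print("snapshot written")


class Storage:
    def __init__(self) -> None:
        self._client = FakeClient()
        self._save_lock = asyncio.Lock()  # created in __post_init__ in the real class

    async def index_done_callback(self) -> None:
        async with self._save_lock:  # only one save may run at a time
            self._client.save()


async def main() -> None:
    storage = Storage()
    # Two callbacks racing: the lock makes the saves run strictly one after the other.
    await asyncio.gather(storage.index_done_callback(), storage.index_done_callback())


asyncio.run(main())
```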

View File

@@ -30,6 +30,7 @@ from ..base import (
     DocStatus,
     DocProcessingStatus,
     BaseGraphStorage,
+    T,
 )

 if sys.platform.startswith("win"):
@@ -442,6 +443,22 @@ class PGDocStatusStorage(DocStatusStorage):
         existed = set([element["id"] for element in result])
         return set(data) - existed

+    async def get_by_id(self, id: str) -> Union[T, None]:
+        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
+        params = {"workspace": self.db.workspace, "id": id}
+        result = await self.db.query(sql, params, True)
+        if result is None:
+            return None
+        else:
+            return DocProcessingStatus(
+                content_length=result[0]["content_length"],
+                content_summary=result[0]["content_summary"],
+                status=result[0]["status"],
+                chunks_count=result[0]["chunks_count"],
+                created_at=result[0]["created_at"],
+                updated_at=result[0]["updated_at"],
+            )
+
     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
         sql = """SELECT status as "status", COUNT(1) as "count"
@@ -884,9 +901,9 @@ class PGGraphStorage(BaseGraphStorage):
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"}) MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH (n)-[r]-(connected) OPTIONAL MATCH (n)-[]-(connected)
RETURN n, r, connected RETURN n, connected
$$) AS (n agtype, r agtype, connected agtype)""" % ( $$) AS (n agtype, connected agtype)""" % (
self.graph_name, self.graph_name,
label, label,
) )

View File

@@ -231,7 +231,7 @@ class LightRAG:
         self.llm_response_cache = self.key_string_value_json_storage_cls(
             namespace="llm_response_cache",
-            embedding_func=None,
+            embedding_func=self.embedding_func,
         )

         ####
@@ -275,7 +275,7 @@ class LightRAG:
         else:
             hashing_kv = self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             )

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
@@ -373,7 +373,7 @@ class LightRAG:
             doc_id
             for doc_id in new_docs.keys()
             if (current_doc := await self.doc_status.get_by_id(doc_id)) is None
-            or current_doc["status"] == DocStatus.FAILED
+            or current_doc.status == DocStatus.FAILED
         }
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
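The caller switches from dict indexing to attribute access because `PGDocStatusStorage.get_by_id` (added above in postgres_impl.py) now returns a `DocProcessingStatus` object instead of a raw row. A trimmed, self-contained restatement; the stand-in classes below only mirror the fields used here and are not the real `lightrag.base` definitions:

```python
from dataclasses import dataclass
from enum import Enum


class DocStatus(str, Enum):          # stand-in; only FAILED is taken from the diff
    FAILED = "failed"
    OTHER = "other"


@dataclass
class DocProcessingStatus:           # stand-in with a subset of the real fields
    status: DocStatus
    content_length: int = 0
    chunks_count: int = 0


current_doc = DocProcessingStatus(status=DocStatus.FAILED)

# Attribute access, as in the updated lightrag.py; failed documents get re-queued.
needs_reindex = current_doc is None or current_doc.status == DocStatus.FAILED
print(needs_reindex)  # True
```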
@@ -916,7 +916,7 @@ class LightRAG:
             else self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
             prompt=prompt,
         )
@@ -933,7 +933,7 @@ class LightRAG:
             else self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
         )
     elif param.mode == "mix":
@@ -952,7 +952,7 @@ class LightRAG:
             else self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
         )
     else:
@@ -993,7 +993,7 @@ class LightRAG:
             or self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
         )
@@ -1024,7 +1024,7 @@ class LightRAG:
             else self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
         )
     elif param.mode == "naive":
@@ -1040,7 +1040,7 @@ class LightRAG:
             else self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
         )
     elif param.mode == "mix":
@@ -1059,7 +1059,7 @@ class LightRAG:
             else self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 global_config=asdict(self),
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             ),
         )
     else:

View File

@@ -352,16 +352,6 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            need_to_restore = False
-            if (
-                global_config["embedding_cache_config"]
-                and global_config["embedding_cache_config"]["enabled"]
-            ):
-                new_config = global_config.copy()
-                new_config["embedding_cache_config"] = None
-                new_config["enable_llm_cache"] = True
-                llm_response_cache.global_config = new_config
-                need_to_restore = True
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text
@@ -370,10 +360,13 @@ async def extract_entities(
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
+                cache_type="extract",
+                force_llm_cache=True,
             )
-            if need_to_restore:
-                llm_response_cache.global_config = global_config
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
                 statistic_data["llm_cache"] += 1
@@ -387,7 +380,12 @@ async def extract_entities(
             res: str = await use_llm_func(input_text)
             await save_to_cache(
                 llm_response_cache,
-                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+                CacheData(
+                    args_hash=arg_hash,
+                    content=res,
+                    prompt=_prompt,
+                    cache_type="extract",
+                ),
             )
             return res
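The key behavioural point of `cache_type="extract"` plus `force_llm_cache=True` is that entity extraction consults and fills the LLM cache even when the global `enable_llm_cache` flag is off; the removed block achieved the same effect by temporarily swapping `global_config`. A toy restatement of the new gate in `handle_cache`; the `should_check_cache` helper is illustrative, not a function in the codebase:

```python
def should_check_cache(global_config: dict, force_llm_cache: bool = False) -> bool:
    # Mirrors the new condition in utils.handle_cache: the per-call override
    # wins over the global enable_llm_cache flag.
    return force_llm_cache or bool(global_config.get("enable_llm_cache"))


print(should_check_cache({"enable_llm_cache": False}))                        # False
print(should_check_cache({"enable_llm_cache": False}, force_llm_cache=True))  # True
```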
@@ -740,7 +738,7 @@ async def extract_keywords_only(
     # 6. Parse out JSON from the LLM response
     match = re.search(r"\{.*\}", result, re.DOTALL)
     if not match:
-        logger.error("No JSON-like structure found in the result.")
+        logger.error("No JSON-like structure found in the LLM response.")
         return [], []
     try:
         keywords_data = json.loads(match.group(0))
@@ -752,7 +750,11 @@ async def extract_keywords_only(
     ll_keywords = keywords_data.get("low_level_keywords", [])

     # 7. Cache only the processed keywords with cache type
-    cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
+    if hl_keywords or ll_keywords:
+        cache_data = {
+            "high_level_keywords": hl_keywords,
+            "low_level_keywords": ll_keywords,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(

View File

@@ -290,9 +290,8 @@ PROMPTS[
 Question 1: {original_prompt}
 Question 2: {cached_prompt}

-Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
-1. Whether these two questions are semantically similar
-2. Whether the answer to Question 2 can be used to answer Question 1
+Please evaluate whether these two questions are semantically similar, and whether the answer to Question 2 can be used to answer Question 1, and provide a similarity score between 0 and 1 directly.

 Similarity score criteria:
 0: Completely unrelated or answer cannot be reused, including but not limited to:
    - The questions have different topics

View File

@@ -58,16 +58,9 @@ class EmbeddingFunc:
     embedding_dim: int
     max_token_size: int
     func: callable
-    concurrent_limit: int = 16
-
-    def __post_init__(self):
-        if self.concurrent_limit != 0:
-            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
-        else:
-            self._semaphore = UnlimitedSemaphore()
+    # concurrent_limit: int = 16

     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        async with self._semaphore:
-            return await self.func(*args, **kwargs)
+        return await self.func(*args, **kwargs)
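With the semaphore removed, `EmbeddingFunc` becomes a thin wrapper around the callable, and throttling is expected to come from `limit_async_func_call` (see the sketch further below). A minimal usage sketch, assuming the import path used in the LightRAG examples and a fake embedding model:

```python
import asyncio

import numpy as np

from lightrag.utils import EmbeddingFunc  # import path as used in the LightRAG examples


async def fake_embed(texts: list[str]) -> np.ndarray:
    # Stand-in model: 8-dimensional zero vectors, purely for illustration.
    return np.zeros((len(texts), 8), dtype=np.float32)


embedding_func = EmbeddingFunc(embedding_dim=8, max_token_size=8192, func=fake_embed)

vectors = asyncio.run(embedding_func(["hello", "world"]))
print(vectors.shape)  # (2, 8) -- __call__ simply awaits the wrapped function now
```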
@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
"""Compute a hash for the given arguments. """Compute a hash for the given arguments.
Args: Args:
*args: Arguments to hash *args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query') cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns: Returns:
str: Hash string str: Hash string
""" """
@@ -131,21 +124,16 @@ def compute_mdhash_id(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()


-def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
-    """Add restriction of maximum async calling times for a async func"""
+def limit_async_func_call(max_size: int):
+    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""

     def final_decro(func):
-        """Not using async.Semaphore to aovid use nest-asyncio"""
-        __current_size = 0
+        sem = asyncio.Semaphore(max_size)

         @wraps(func)
         async def wait_func(*args, **kwargs):
-            nonlocal __current_size
-            while __current_size >= max_size:
-                await asyncio.sleep(waitting_time)
-            __current_size += 1
-            result = await func(*args, **kwargs)
-            __current_size -= 1
-            return result
+            async with sem:
+                result = await func(*args, **kwargs)
+                return result

         return wait_func
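Usage is unchanged from the caller's point of view (lightrag.py still wraps the LLM function with `limit_async_func_call(self.llm_model_max_async)`), but the waiting is now done by a real `asyncio.Semaphore` instead of a sleep-and-poll loop. A self-contained sketch, assuming the usual `lightrag.utils` import path:

```python
import asyncio

from lightrag.utils import limit_async_func_call  # assumed import path


@limit_async_func_call(max_size=2)
async def call_model(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for an LLM or embedding request
    return i


async def main() -> None:
    # Six calls are queued, but the semaphore lets at most two run at any moment.
    results = await asyncio.gather(*(call_model(i) for i in range(6)))
    print(results)  # [0, 1, 2, 3, 4, 5]


asyncio.run(main())
```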
@@ -380,6 +368,9 @@ async def get_best_cached_response(
     original_prompt=None,
     cache_type=None,
 ) -> Union[str, None]:
+    logger.debug(
+        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
+    )
     mode_cache = await hashing_kv.get_by_id(mode)
     if not mode_cache:
         return None
@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)


-def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
+def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
     """Quantize embedding to specified bits"""
+    # Convert list to numpy array if needed
+    if isinstance(embedding, list):
+        embedding = np.array(embedding)
+
     # Calculate min/max values for reconstruction
     min_val = embedding.min()
     max_val = embedding.max()
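The widened signature lets callers pass a plain Python list (for example, a raw embedding returned by an HTTP API) without converting it first. A small sketch, assuming the usual `lightrag.utils` import path:

```python
from lightrag.utils import quantize_embedding  # assumed import path

# A plain list now works; previously an np.ndarray was required.
embedding = [0.12, -0.5, 0.33, 0.99]

quantized, min_val, max_val = quantize_embedding(embedding, bits=8)
print(quantized)         # quantized codes for each dimension
print(min_val, max_val)  # kept alongside the codes so the vector can be dequantized
```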
@@ -491,21 +486,21 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)


-async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
+async def handle_cache(
+    hashing_kv,
+    args_hash,
+    prompt,
+    mode="default",
+    cache_type=None,
+    force_llm_cache=False,
+):
     """Generic cache handling function"""
-    if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
+    if hashing_kv is None or not (
+        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
+    ):
         return None, None, None, None

-    # For default mode, only use simple cache matching
-    if mode == "default":
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
-        return None, None, None, None
-
+    if mode != "default":
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -517,10 +512,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
         quantized = min_val = max_val = None
         if is_embedding_cache_enabled:
             # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
-            current_embedding = await embedding_model_func([prompt])
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
             best_cached_response = await get_best_cached_response(
                 hashing_kv,
@@ -529,12 +522,15 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
                 mode=mode,
                 use_llm_check=use_llm_check,
                 llm_func=llm_model_func if use_llm_check else None,
-                original_prompt=prompt if use_llm_check else None,
+                original_prompt=prompt,
                 cache_type=cache_type,
             )
             if best_cached_response is not None:
                 return best_cached_response, None, None, None
             else:
+                return None, quantized, min_val, max_val

+    # For default mode (extract_entities or naive query), or when the embedding cache is disabled
     # Use regular cache
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -543,7 +539,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
     if args_hash in mode_cache:
         return mode_cache[args_hash]["return"], None, None, None
-    return None, quantized, min_val, max_val
+    return None, None, None, None


 @dataclass
@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         mode_cache[cache_data.args_hash] = {
             "return": cache_data.content,
+            "cache_type": cache_data.cache_type,
             "embedding": cache_data.quantized.tobytes().hex()
             if cache_data.quantized is not None
             else None,