Merge branch 'HKUDS:main' into main

This commit is contained in:
Saifeddine ALOUI
2025-02-03 11:24:08 +01:00
committed by GitHub
8 changed files with 193 additions and 132 deletions

View File

@@ -455,9 +455,38 @@ For production level scenarios you will most likely want to leverage an enterpri
* If you prefer Docker and are new to it, start with this image to avoid hiccups (do read the overview first): https://hub.docker.com/r/shangor/postgres-for-rag
* How to start? See [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
* Create indexes for AGE (change `dickens` below to your graph name if necessary); a Python sketch for applying them follows this list:
```sql
LOAD 'age';
SET search_path = ag_catalog, "$user", public;
CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
CREATE INDEX CONCURRENTLY entity_p_idx ON dickens."Entity" (id);
CREATE INDEX CONCURRENTLY vertex_p_idx ON dickens."_ag_label_vertex" (id);
CREATE INDEX CONCURRENTLY directed_p_idx ON dickens."DIRECTED" (id);
CREATE INDEX CONCURRENTLY directed_eid_idx ON dickens."DIRECTED" (end_id);
CREATE INDEX CONCURRENTLY directed_sid_idx ON dickens."DIRECTED" (start_id);
CREATE INDEX CONCURRENTLY directed_seid_idx ON dickens."DIRECTED" (start_id,end_id);
CREATE INDEX CONCURRENTLY edge_p_idx ON dickens."_ag_label_edge" (id);
CREATE INDEX CONCURRENTLY edge_sid_idx ON dickens."_ag_label_edge" (start_id);
CREATE INDEX CONCURRENTLY edge_eid_idx ON dickens."_ag_label_edge" (end_id);
CREATE INDEX CONCURRENTLY edge_seid_idx ON dickens."_ag_label_edge" (start_id,end_id);
CREATE INDEX CONCURRENTLY vertex_idx_node_id ON dickens."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
CREATE INDEX CONCURRENTLY entity_idx_node_id ON dickens."Entity" (ag_catalog.agtype_access_operator(properties, '"node_id"'::agtype));
CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON dickens."Entity" USING gin (properties);
ALTER TABLE dickens."DIRECTED" CLUSTER ON directed_sid_idx;
-- drop if necessary
DROP INDEX entity_p_idx;
DROP INDEX vertex_p_idx;
DROP INDEX directed_p_idx;
DROP INDEX directed_eid_idx;
DROP INDEX directed_sid_idx;
DROP INDEX directed_seid_idx;
DROP INDEX edge_p_idx;
DROP INDEX edge_sid_idx;
DROP INDEX edge_eid_idx;
DROP INDEX edge_seid_idx;
DROP INDEX vertex_idx_node_id;
DROP INDEX entity_idx_node_id;
DROP INDEX entity_node_id_gin_idx;
```
* Known issue with Apache AGE: the released versions have the following issue:
> You might find that the properties of the nodes/edges are empty.
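
If you'd rather apply the index statements above from Python than from `psql`, a minimal sketch (assuming `psycopg2`, the default `dickens` graph name, and illustrative connection settings) could look like this. Note that `CREATE INDEX CONCURRENTLY` cannot run inside a transaction block, hence the autocommit:

```python
# Hypothetical helper, not part of LightRAG: applies the AGE index statements above.
import psycopg2

INDEX_STATEMENTS = [
    'CREATE INDEX CONCURRENTLY IF NOT EXISTS vertex_p_idx ON dickens."_ag_label_vertex" (id)',
    'CREATE INDEX CONCURRENTLY IF NOT EXISTS edge_p_idx ON dickens."_ag_label_edge" (id)',
    # ... add the remaining statements from the block above
]

def create_age_indexes(dsn: str) -> None:
    conn = psycopg2.connect(dsn)
    conn.autocommit = True  # CONCURRENTLY must run outside an explicit transaction
    try:
        with conn.cursor() as cur:
            cur.execute("LOAD 'age'")
            cur.execute('SET search_path = ag_catalog, "$user", public')
            for stmt in INDEX_STATEMENTS:
                cur.execute(stmt)
    finally:
        conn.close()

if __name__ == "__main__":
    # Connection string is an example; use your own credentials and database.
    create_age_indexes("postgresql://postgres:postgres@localhost:5432/postgres")
```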

View File

@@ -13,18 +13,6 @@ from fastapi import (
from typing import Dict
import threading
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = threading.Lock()
import json
import os
@@ -34,7 +22,7 @@ import logging
import argparse
import time
import re
from typing import List, Dict, Any, Optional, Union
from typing import List, Any, Optional, Union
from lightrag import LightRAG, QueryParam
from lightrag.api import __api_version__
@@ -57,8 +45,21 @@ import pipmaster as pm
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = threading.Lock()
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
@@ -918,6 +919,12 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=False, # set to True for debugging to reduce LLM cost
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
)
else:
rag = LightRAG(
@@ -941,6 +948,12 @@ def create_app(args):
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=False, # set to True for debugging to reduce LLM cost
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
)
async def index_file(file_path: Union[str, Path]) -> None:
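
For reference, the relocated `scan_progress` dict and `progress_lock` are meant to be updated together so readers never observe a half-written progress state. An illustrative update helper (the function name is hypothetical, not the server's exact code):

```python
# Illustrative sketch only: how a scanning worker might update the shared
# scan_progress dict under progress_lock (both defined above in this file).
def update_scan_progress(current_file: str, indexed_count: int, total_files: int) -> None:
    with progress_lock:
        scan_progress["current_file"] = current_file
        scan_progress["indexed_count"] = indexed_count
        scan_progress["total_files"] = total_files
        scan_progress["progress"] = (
            (indexed_count / total_files) * 100 if total_files else 0
        )
```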

View File

@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
# Initialize lock only for file operations
self._save_lock = asyncio.Lock()
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
embedding = embedding[0]
logger.info(
f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
)
results = self._client.query(
query=embedding,
@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self):
self._client.save()
# Protect file write operation
async with self._save_lock:
self._client.save()
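
The new `_save_lock` serializes concurrent `index_done_callback` invocations so the vector DB file is never written by two coroutines at once. A stripped-down sketch of the pattern (class and client are stand-ins, not the real implementation):

```python
import asyncio

class SaveGuardedStorage:
    """Illustrative stand-in: serializes file writes with an asyncio.Lock."""

    def __init__(self, client):
        self._client = client             # e.g. a NanoVectorDB instance
        self._save_lock = asyncio.Lock()  # created once in __post_init__ in the real class

    async def index_done_callback(self):
        # Only one coroutine may persist the storage file at a time.
        async with self._save_lock:
            self._client.save()
```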

View File

@@ -30,6 +30,7 @@ from ..base import (
DocStatus,
DocProcessingStatus,
BaseGraphStorage,
T,
)
if sys.platform.startswith("win"):
@@ -442,6 +443,22 @@ class PGDocStatusStorage(DocStatusStorage):
existed = set([element["id"] for element in result])
return set(data) - existed
async def get_by_id(self, id: str) -> Union[T, None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
params = {"workspace": self.db.workspace, "id": id}
result = await self.db.query(sql, params, True)
if result is None:
return None
else:
return DocProcessingStatus(
content_length=result[0]["content_length"],
content_summary=result[0]["content_summary"],
status=result[0]["status"],
chunks_count=result[0]["chunks_count"],
created_at=result[0]["created_at"],
updated_at=result[0]["updated_at"],
)
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count"
@@ -884,9 +901,9 @@ class PGGraphStorage(BaseGraphStorage):
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected
$$) AS (n agtype, r agtype, connected agtype)""" % (
OPTIONAL MATCH (n)-[]-(connected)
RETURN n, connected
$$) AS (n agtype, connected agtype)""" % (
self.graph_name,
label,
)
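
With `get_by_id` now returning a `DocProcessingStatus` (or `None`) instead of a raw dict, callers can use attribute access, which is exactly what the lightrag.py hunk below switches to. An illustrative caller (`doc_status`, `doc_id`, and `pending_doc_ids` are assumed names):

```python
# Illustrative usage; doc_status is an already-configured PGDocStatusStorage.
status = await doc_status.get_by_id(doc_id)
if status is None or status.status == DocStatus.FAILED:
    # Document was never indexed, or a previous attempt failed: queue it again.
    pending_doc_ids.append(doc_id)
```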

View File

@@ -231,7 +231,7 @@ class LightRAG:
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
embedding_func=self.embedding_func,
)
####
@@ -275,7 +275,7 @@ class LightRAG:
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
embedding_func=self.embedding_func,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
@@ -373,7 +373,7 @@ class LightRAG:
doc_id
for doc_id in new_docs.keys()
if (current_doc := await self.doc_status.get_by_id(doc_id)) is None
or current_doc["status"] == DocStatus.FAILED
or current_doc.status == DocStatus.FAILED
}
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
@@ -916,7 +916,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
prompt=prompt,
)
@@ -933,7 +933,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
@@ -952,7 +952,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
else:
@@ -993,7 +993,7 @@ class LightRAG:
or self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
@@ -1024,7 +1024,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "naive":
@@ -1040,7 +1040,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
elif param.mode == "mix":
@@ -1059,7 +1059,7 @@ class LightRAG:
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
embedding_func=self.embedding_func,
),
)
else:
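
The recurring `embedding_func=None` → `embedding_func=self.embedding_func` change is what enables embedding-based cache lookups: as the utils.py hunks below show, `handle_cache` now calls `hashing_kv.embedding_func([prompt])` on the cache storage itself, so a cache built without an embedding function cannot participate in similarity matching. A minimal sketch of that dependency (illustrative function name):

```python
# Illustrative: the cache KV storage must carry an embedding function,
# because similarity lookups embed the prompt through the storage object.
async def lookup_similar_cache(hashing_kv, prompt: str):
    current_embedding = await hashing_kv.embedding_func([prompt])  # fails if embedding_func is None
    return current_embedding[0]
```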

View File

@@ -352,16 +352,6 @@ async def extract_entities(
input_text: str, history_messages: list[dict[str, str]] = None
) -> str:
if enable_llm_cache_for_entity_extract and llm_response_cache:
need_to_restore = False
if (
global_config["embedding_cache_config"]
and global_config["embedding_cache_config"]["enabled"]
):
new_config = global_config.copy()
new_config["embedding_cache_config"] = None
new_config["enable_llm_cache"] = True
llm_response_cache.global_config = new_config
need_to_restore = True
if history_messages:
history = json.dumps(history_messages, ensure_ascii=False)
_prompt = history + "\n" + input_text
@@ -370,10 +360,13 @@ async def extract_entities(
arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
llm_response_cache,
arg_hash,
_prompt,
"default",
cache_type="extract",
force_llm_cache=True,
)
if need_to_restore:
llm_response_cache.global_config = global_config
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
@@ -387,7 +380,12 @@ async def extract_entities(
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
CacheData(
args_hash=arg_hash,
content=res,
prompt=_prompt,
cache_type="extract",
),
)
return res
@@ -740,7 +738,7 @@ async def extract_keywords_only(
# 6. Parse out JSON from the LLM response
match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.error("No JSON-like structure found in the result.")
logger.error("No JSON-like structure found in the LLM respond.")
return [], []
try:
keywords_data = json.loads(match.group(0))
@@ -752,20 +750,24 @@ async def extract_keywords_only(
ll_keywords = keywords_data.get("low_level_keywords", [])
# 7. Cache only the processed keywords with cache type
cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=json.dumps(cache_data),
prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode,
cache_type="keywords",
),
)
if hl_keywords or ll_keywords:
cache_data = {
"high_level_keywords": hl_keywords,
"low_level_keywords": ll_keywords,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=json.dumps(cache_data),
prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode,
cache_type="keywords",
),
)
return hl_keywords, ll_keywords
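
Since the keyword cache is now written only when something was actually extracted, an empty result signals a genuine extraction failure rather than a cached empty answer. An illustrative caller (argument names are assumptions; consult `extract_keywords_only`'s real signature in operate.py):

```python
# Illustrative only; argument names below are placeholders.
hl_keywords, ll_keywords = await extract_keywords_only(
    query_text, query_param, global_config, llm_response_cache
)
if not hl_keywords and not ll_keywords:
    # Nothing parsed from the LLM response; fall back to a non-keyword query path.
    ...
```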

View File

@@ -290,9 +290,8 @@ PROMPTS[
Question 1: {original_prompt}
Question 2: {cached_prompt}
Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
1. Whether these two questions are semantically similar
2. Whether the answer to Question 2 can be used to answer Question 1
Please evaluate whether these two questions are semantically similar and whether the answer to Question 2 can be used to answer Question 1, and provide a similarity score between 0 and 1 directly.
Similarity score criteria:
0: Completely unrelated or answer cannot be reused, including but not limited to:
- The questions have different topics

View File

@@ -58,17 +58,10 @@ class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
concurrent_limit: int = 16
def __post_init__(self):
if self.concurrent_limit != 0:
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
else:
self._semaphore = UnlimitedSemaphore()
# concurrent_limit: int = 16
async def __call__(self, *args, **kwargs) -> np.ndarray:
async with self._semaphore:
return await self.func(*args, **kwargs)
return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
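
With the per-instance semaphore removed, `EmbeddingFunc` is now a thin async wrapper; concurrency is expected to be capped elsewhere (e.g. via `limit_async_func_call`). An illustrative construction with a stand-in embedding callable:

```python
import numpy as np

async def dummy_embed(texts: list[str]) -> np.ndarray:
    # Stand-in embedding function: returns zero vectors of the declared dimension.
    return np.zeros((len(texts), 768), dtype=np.float32)

embedding_func = EmbeddingFunc(embedding_dim=768, max_token_size=8192, func=dummy_embed)
# __call__ now simply awaits the wrapped function; no semaphore is acquired:
# vectors = await embedding_func(["hello world"])
```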
@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
"""Compute a hash for the given arguments.
Args:
*args: Arguments to hash
cache_type: Type of cache (e.g., 'keywords', 'query')
cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
Returns:
str: Hash string
"""
@@ -131,22 +124,17 @@ def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
def limit_async_func_call(max_size: int):
"""Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
def final_decro(func):
"""Not using async.Semaphore to aovid use nest-asyncio"""
__current_size = 0
sem = asyncio.Semaphore(max_size)
@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
result = await func(*args, **kwargs)
__current_size -= 1
return result
async with sem:
result = await func(*args, **kwargs)
return result
return wait_func
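
Usage of the limiter is unchanged (lightrag.py wraps the LLM function the same way, as shown above). A self-contained sketch with a stand-in coroutine (import path assumed to be `lightrag.utils`):

```python
import asyncio
from lightrag.utils import limit_async_func_call  # path assumed

async def fake_llm(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for a real LLM request
    return f"echo: {prompt}"

# At most 4 calls run concurrently; the rest wait on the internal semaphore.
limited_llm = limit_async_func_call(4)(fake_llm)

async def main() -> None:
    results = await asyncio.gather(*(limited_llm(f"q{i}") for i in range(16)))
    print(len(results))  # 16

asyncio.run(main())
```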
@@ -380,6 +368,9 @@ async def get_best_cached_response(
original_prompt=None,
cache_type=None,
) -> Union[str, None]:
logger.debug(
f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
)
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None
@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
return dot_product / (norm1 * norm2)
def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
"""Quantize embedding to specified bits"""
# Convert list to numpy array if needed
if isinstance(embedding, list):
embedding = np.array(embedding)
# Calculate min/max values for reconstruction
min_val = embedding.min()
max_val = embedding.max()
@@ -491,59 +486,60 @@ def dequantize_embedding(
return (quantized * scale + min_val).astype(np.float32)
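
A quick round-trip with the now list-friendly `quantize_embedding` and its counterpart `dequantize_embedding` (values are illustrative; the dequantize signature is assumed to be `(quantized, min_val, max_val, bits)` as suggested by the surrounding code):

```python
import numpy as np

embedding = [0.12, -0.5, 0.33, 0.99]  # a plain list is converted to an ndarray internally
quantized, min_val, max_val = quantize_embedding(embedding, bits=8)
restored = dequantize_embedding(quantized, min_val, max_val, bits=8)
print(np.abs(np.array(embedding) - restored).max())  # small quantization error
```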
async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
async def handle_cache(
hashing_kv,
args_hash,
prompt,
mode="default",
cache_type=None,
force_llm_cache=False,
):
"""Generic cache handling function"""
if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
if hashing_kv is None or not (
force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
):
return None, None, None, None
# For default mode, only use simple cache matching
if mode == "default":
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, None, None, None
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
use_llm_check = embedding_cache_config.get("use_llm_check", False)
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
llm_model_func = hashing_kv.global_config.get("llm_model_func")
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt if use_llm_check else None,
cache_type=cache_type,
if mode != "default":
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
# Use regular cache
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
is_embedding_cache_enabled = embedding_cache_config["enabled"]
use_llm_check = embedding_cache_config.get("use_llm_check", False)
return None, quantized, min_val, max_val
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
current_embedding = await hashing_kv.embedding_func([prompt])
llm_model_func = hashing_kv.global_config.get("llm_model_func")
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt,
cache_type=cache_type,
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
return None, quantized, min_val, max_val
# For default mode (extract_entities or naive query), or when the embedding cache is disabled:
# Use regular cache
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, None, None, None
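
Putting the new signature together, the extraction path (mirroring the operate.py hunk above) now calls the cache roughly like this inside the async extraction helper; variable names are illustrative:

```python
arg_hash = compute_args_hash(_prompt)
cached, quantized, min_val, max_val = await handle_cache(
    llm_response_cache,
    arg_hash,
    _prompt,
    "default",             # extraction still uses the default mode...
    cache_type="extract",  # ...but tags entries as extraction cache
    force_llm_cache=True,  # honor the cache even when enable_llm_cache is off
)
if cached is not None:
    result = cached
else:
    result = await use_llm_func(input_text)
    await save_to_cache(
        llm_response_cache,
        CacheData(args_hash=arg_hash, content=result, prompt=_prompt, cache_type="extract"),
    )
```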
@dataclass
@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"cache_type": cache_data.cache_type,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,