Merge pull request #1036 from danielaskdd/neo4j-add-min-degree
Refactoring Neo4j implementation and fixing storage init problem for Gunicorn
@@ -50,9 +50,6 @@ from .auth import auth_handler
 # This update allows the user to put a different.env file for each lightrag folder
 load_dotenv(".env", override=True)
 
-# Read entity extraction cache config
-enable_llm_cache = os.getenv("ENABLE_LLM_CACHE_FOR_EXTRACT", "false").lower() == "true"
-
 # Initialize config parser
 config = configparser.ConfigParser()
 config.read("config.ini")
@@ -144,23 +141,25 @@ def create_app(args):
         try:
             # Initialize database connections
             await rag.initialize_storages()
-            await initialize_pipeline_status()
 
+            await initialize_pipeline_status()
+            pipeline_status = await get_namespace_data("pipeline_status")
+
+            should_start_autoscan = False
+            async with get_pipeline_status_lock():
                 # Auto scan documents if enabled
                 if args.auto_scan_at_startup:
-                    # Check if a task is already running (with lock protection)
-                    pipeline_status = await get_namespace_data("pipeline_status")
-                    should_start_task = False
-                    async with get_pipeline_status_lock():
-                        if not pipeline_status.get("busy", False):
-                            should_start_task = True
-                    # Only start the task if no other task is running
-                    if should_start_task:
+                    if not pipeline_status.get("autoscanned", False):
+                        pipeline_status["autoscanned"] = True
+                        should_start_autoscan = True
+
+            # Only run auto scan when no other process started it first
+            if should_start_autoscan:
                     # Create background task
                     task = asyncio.create_task(run_scanning_process(rag, doc_manager))
                     app.state.background_tasks.add(task)
                     task.add_done_callback(app.state.background_tasks.discard)
-                    logger.info("Auto scan task started at startup.")
+                    logger.info(f"Process {os.getpid()} auto scan task started at startup.")
 
             ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
 
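Note: the point of this hunk is that every Gunicorn worker runs the same startup path, so the shared "autoscanned" flag decides which process actually launches the scan. A minimal sketch of that pattern, as an in-process simulation (the real flag and lock live in lightrag.kg.shared_storage):

import asyncio

# Stand-ins for the shared pipeline-status namespace and its lock.
pipeline_status = {"autoscanned": False}
pipeline_status_lock = asyncio.Lock()

async def startup_hook(worker_id: int) -> None:
    should_start_autoscan = False
    async with pipeline_status_lock:
        if not pipeline_status.get("autoscanned", False):
            pipeline_status["autoscanned"] = True  # claim the job atomically
            should_start_autoscan = True
    if should_start_autoscan:
        print(f"worker {worker_id} starts the auto scan")  # exactly one worker prints

async def main() -> None:
    # Simulate several workers booting concurrently.
    await asyncio.gather(*(startup_hook(i) for i in range(4)))

asyncio.run(main())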
@@ -326,7 +325,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=enable_llm_cache,  # Read from environment variable
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -355,7 +354,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=enable_llm_cache,  # Read from environment variable
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -419,6 +418,7 @@ def create_app(args):
             "doc_status_storage": args.doc_status_storage,
             "graph_storage": args.graph_storage,
             "vector_storage": args.vector_storage,
+            "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
         },
         "update_status": update_status,
     }
@@ -362,6 +362,11 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
 
+    # Inject LLM cache configuration
+    args.enable_llm_cache_for_extract = get_env_value(
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
+    )
+
     # Select Document loading tool (DOCLING, DEFAULT)
     args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
 
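Note: get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool) reads the same variable the deleted server-side os.getenv line used to parse. A hypothetical re-implementation of the boolean case, just to pin down the expected behavior (the real helper lives in the API utils and may accept a different set of truthy spellings):

import os

def env_bool(name: str, default: bool) -> bool:
    # Fall back to the default when the variable is unset.
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes")

print(env_bool("ENABLE_LLM_CACHE_FOR_EXTRACT", False))  # False unless the env var is set truthy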
@@ -457,8 +462,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.history_turns}")
     ASCIIColors.white(" ├─ Cosine Threshold: ", end="")
     ASCIIColors.yellow(f"{args.cosine_threshold}")
-    ASCIIColors.white(" └─ Top-K: ", end="")
+    ASCIIColors.white(" ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
+    ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
+    ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")
 
     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")
@@ -15,6 +15,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
 
@@ -27,20 +31,24 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None
 
     async def initialize(self):
         """Initialize storage data"""
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
             # check need_init must before get_namespace_data
-            need_init = try_initialize_namespace(self.namespace)
+            need_init = await try_initialize_namespace(self.namespace)
             self._data = await get_namespace_data(self.namespace)
             if need_init:
                 loaded_data = load_json(self._file_name) or {}
                 async with self._storage_lock:
                     self._data.update(loaded_data)
                     logger.info(
-                        f"Loaded document status storage with {len(loaded_data)} records"
+                        f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
                     )
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
@@ -87,18 +95,24 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
+            if self.storage_updated.value:
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
+                logger.info(
+                    f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
+            await set_all_update_flags(self.namespace)
 
         await self.index_done_callback()
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
@@ -109,9 +123,12 @@ class JsonDocStatusStorage(DocStatusStorage):
         async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()
 
     async def drop(self) -> None:
         """Drop the storage"""
         async with self._storage_lock:
             self._data.clear()
+            await set_all_update_flags(self.namespace)
+        await self.index_done_callback()
@@ -13,6 +13,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
 
@@ -23,26 +27,63 @@ class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None
 
     async def initialize(self):
         """Initialize storage data"""
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
             # check need_init must before get_namespace_data
-            need_init = try_initialize_namespace(self.namespace)
+            need_init = await try_initialize_namespace(self.namespace)
             self._data = await get_namespace_data(self.namespace)
             if need_init:
                 loaded_data = load_json(self._file_name) or {}
                 async with self._storage_lock:
                     self._data.update(loaded_data)
-                    logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
+
+                    # Calculate data count based on namespace
+                    if self.namespace.endswith("cache"):
+                        # For cache namespaces, sum the cache entries across all cache types
+                        data_count = sum(
+                            len(first_level_dict)
+                            for first_level_dict in loaded_data.values()
+                            if isinstance(first_level_dict, dict)
+                        )
+                    else:
+                        # For non-cache namespaces, use the original count method
+                        data_count = len(loaded_data)
+
+                    logger.info(
+                        f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
+                    )
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
+            if self.storage_updated.value:
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
+
+                # Calculate data count based on namespace
+                if self.namespace.endswith("cache"):
+                    # For cache namespaces, sum the cache entries across all cache types
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in data_dict.values()
+                        if isinstance(first_level_dict, dict)
+                    )
+                else:
+                    # For non-cache namespaces, use the original count method
+                    data_count = len(data_dict)
+
+                logger.info(
+                    f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)
 
     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage
@@ -73,15 +114,16 @@ class JsonKVStorage(BaseKVStorage):
         return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
-            left_data = {k: v for k, v in data.items() if k not in self._data}
-            self._data.update(left_data)
+            self._data.update(data)
+            await set_all_update_flags(self.namespace)
 
     async def delete(self, ids: list[str]) -> None:
         async with self._storage_lock:
            for doc_id in ids:
                self._data.pop(doc_id, None)
+           await set_all_update_flags(self.namespace)
        await self.index_done_callback()
File diff suppressed because it is too large
@@ -7,11 +7,17 @@ from typing import Any, Dict, Optional, Union, TypeVar, Generic
 
 
 # Define a direct print function for critical logs that must be visible in all processes
-def direct_log(message, level="INFO"):
+def direct_log(message, level="INFO", enable_output: bool = True):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.
+
+    Args:
+        message: The message to log
+        level: Log level (default: "INFO")
+        enable_output: Whether to actually output the log (default: True)
     """
-    print(f"{level}: {message}", file=sys.stderr, flush=True)
+    if enable_output:
+        print(f"{level}: {message}", file=sys.stderr, flush=True)
 
 
@@ -32,55 +38,165 @@ _update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated
 _storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
+_graph_db_lock: Optional[LockType] = None
+_data_init_lock: Optional[LockType] = None
 
 
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
 
-    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
+    def __init__(
+        self,
+        lock: Union[ProcessLock, asyncio.Lock],
+        is_async: bool,
+        name: str = "unnamed",
+        enable_logging: bool = True,
+    ):
         self._lock = lock
         self._is_async = is_async
+        self._pid = os.getpid()  # for debug only
+        self._name = name  # for debug only
+        self._enable_logging = enable_logging  # for debug only
 
     async def __aenter__(self) -> "UnifiedLock[T]":
+        try:
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 await self._lock.acquire()
             else:
                 self._lock.acquire()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             return self
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
+        try:
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 self._lock.release()
             else:
                 self._lock.release()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise
 
     def __enter__(self) -> "UnifiedLock[T]":
         """For backward compatibility"""
+        try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.acquire()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                enable_output=self._enable_logging,
+            )
             return self
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """For backward compatibility"""
+        try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.release()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                enable_output=self._enable_logging,
+            )
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise
 
 
-def get_internal_lock() -> UnifiedLock:
+def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_internal_lock,
+        is_async=not is_multiprocess,
+        name="internal_lock",
+        enable_logging=enable_logging,
+    )
 
 
-def get_storage_lock() -> UnifiedLock:
+def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_storage_lock,
+        is_async=not is_multiprocess,
+        name="storage_lock",
+        enable_logging=enable_logging,
+    )
 
 
-def get_pipeline_status_lock() -> UnifiedLock:
+def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_pipeline_status_lock,
+        is_async=not is_multiprocess,
+        name="pipeline_status_lock",
+        enable_logging=enable_logging,
+    )
+
+
+def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
+    """return unified graph database lock for ensuring atomic operations"""
+    return UnifiedLock(
+        lock=_graph_db_lock,
+        is_async=not is_multiprocess,
+        name="graph_db_lock",
+        enable_logging=enable_logging,
+    )
+
+
+def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
+    """return unified data initialization lock for ensuring atomic data initialization"""
+    return UnifiedLock(
+        lock=_data_init_lock,
+        is_async=not is_multiprocess,
+        name="data_init_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def initialize_share_data(workers: int = 1):
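Note: a minimal usage sketch for the named locks, assuming the refactored shared_storage module is importable and initialize_share_data() has run in this process. Every accessor returns a UnifiedLock that must be entered with "async with"; pass enable_logging=True to trace acquire/release per PID via direct_log:

import asyncio
from lightrag.kg.shared_storage import initialize_share_data, get_graph_db_lock

async def merge_step() -> None:
    graph_db_lock = get_graph_db_lock(enable_logging=True)
    async with graph_db_lock:  # serializes node/edge merge+upsert across workers
        ...  # merge nodes and edges, then upsert

initialize_share_data(workers=1)  # single-process mode -> plain asyncio.Lock
asyncio.run(merge_step())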
@@ -108,6 +224,8 @@ def initialize_share_data(workers: int = 1):
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
+        _graph_db_lock, \
+        _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
@@ -120,14 +238,16 @@ def initialize_share_data(workers: int = 1):
         )
         return
 
-    _manager = Manager()
     _workers = workers
 
     if workers > 1:
         is_multiprocess = True
+        _manager = Manager()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()
+        _graph_db_lock = _manager.Lock()
+        _data_init_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()
@@ -139,6 +259,8 @@ def initialize_share_data(workers: int = 1):
         _internal_lock = asyncio.Lock()
         _storage_lock = asyncio.Lock()
         _pipeline_status_lock = asyncio.Lock()
+        _graph_db_lock = asyncio.Lock()
+        _data_init_lock = asyncio.Lock()
         _shared_dicts = {}
         _init_flags = {}
         _update_flags = {}
@@ -164,6 +286,7 @@ async def initialize_pipeline_status():
     history_messages = _manager.list() if is_multiprocess else []
     pipeline_namespace.update(
         {
+            "autoscanned": False,  # Auto-scan started
             "busy": False,  # Control concurrent processes
             "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
             "job_start": None,  # Job start time
@@ -200,7 +323,12 @@ async def get_update_flag(namespace: str):
         if is_multiprocess and _manager is not None:
             new_update_flag = _manager.Value("b", False)
         else:
-            new_update_flag = False
+            # Create a simple mutable object to store boolean value for compatibility with mutiprocess
+            class MutableBoolean:
+                def __init__(self, initial_value=False):
+                    self.value = initial_value
+
+            new_update_flag = MutableBoolean(False)
 
         _update_flags[namespace].append(new_update_flag)
     return new_update_flag
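Note: why the wrapper is needed. A storage instance keeps a reference to its flag, and set_all_update_flags() must be able to flip it in place; a plain bool appended to the registry cannot be updated through that held reference, while an object with a .value attribute (mirroring Manager().Value("b", ...)) can. A tiny self-contained demonstration:

class MutableBoolean:
    """Mimics Manager().Value('b', ...) for single-process mode."""
    def __init__(self, initial_value=False):
        self.value = initial_value

held_by_storage = MutableBoolean(False)  # what initialize() stored on the instance
registry = [held_by_storage]             # what the per-namespace flag registry holds

registry[0].value = True                 # the set_all_update_flags() path
print(held_by_storage.value)             # True -- the change is visible to the holder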
@@ -220,7 +348,26 @@ async def set_all_update_flags(namespace: str):
             if is_multiprocess:
                 _update_flags[namespace][i].value = True
             else:
-                _update_flags[namespace][i] = True
+                # Use .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = True
+
+
+async def clear_all_update_flags(namespace: str):
+    """Clear all update flag of namespace indicating all workers need to reload data from files"""
+    global _update_flags
+    if _update_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    async with get_internal_lock():
+        if namespace not in _update_flags:
+            raise ValueError(f"Namespace {namespace} not found in update flags")
+        # Update flags for both modes
+        for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
+                _update_flags[namespace][i].value = False
+            else:
+                # Use .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = False
 
 
 async def get_all_update_flags_status() -> Dict[str, list]:
@@ -247,7 +394,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     return result
 
 
-def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(namespace: str) -> bool:
     """
     Returns True if the current worker(process) gets initialization permission for loading data later.
     The worker does not get the permission is prohibited to load data from files.
@@ -257,6 +404,7 @@ async def try_initialize_namespace(namespace: str) -> bool:
     if _init_flags is None:
         raise ValueError("Try to create nanmespace before Shared-Data is initialized")
 
+    async with get_internal_lock():
         if namespace not in _init_flags:
             _init_flags[namespace] = True
             direct_log(
@@ -266,6 +414,7 @@ async def try_initialize_namespace(namespace: str) -> bool:
             direct_log(
                 f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
             )
+
     return False
 
 
@@ -304,6 +453,8 @@ def finalize_share_data():
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
+        _graph_db_lock, \
+        _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
@@ -369,6 +520,8 @@ def finalize_share_data():
     _storage_lock = None
     _internal_lock = None
     _pipeline_status_lock = None
+    _graph_db_lock = None
+    _data_init_lock = None
     _update_flags = None
 
     direct_log(f"Process {os.getpid()} storage data finalization complete")
@@ -354,6 +354,9 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
+            global_config=asdict(
+                self
+            ),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )
 
@@ -404,18 +407,8 @@ class LightRAG:
             embedding_func=None,
         )
 
-        if self.llm_response_cache and hasattr(
-            self.llm_response_cache, "global_config"
-        ):
-            hashing_kv = self.llm_response_cache
-        else:
-            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
+        # Directly use llm_response_cache, don't create a new object
+        hashing_kv = self.llm_response_cache
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
@@ -590,6 +583,7 @@ class LightRAG:
             split_by_character, split_by_character_only
         )
 
+    # TODO: deprecated, use insert instead
     def insert_custom_chunks(
         self,
         full_text: str,
@@ -601,6 +595,7 @@ class LightRAG:
             self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
         )
 
+    # TODO: deprecated, use ainsert instead
     async def ainsert_custom_chunks(
         self, full_text: str, text_chunks: list[str], doc_id: str | None = None
     ) -> None:
@@ -892,7 +887,9 @@ class LightRAG:
                 self.chunks_vdb.upsert(chunks)
             )
             entity_relation_task = asyncio.create_task(
-                self._process_entity_relation_graph(chunks)
+                self._process_entity_relation_graph(
+                    chunks, pipeline_status, pipeline_status_lock
+                )
             )
             full_docs_task = asyncio.create_task(
                 self.full_docs.upsert(
@@ -1007,21 +1004,27 @@ class LightRAG:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
 
-    async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
+    async def _process_entity_relation_graph(
+        self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
+    ) -> None:
         try:
             await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
                 relationships_vdb=self.relationships_vdb,
-                llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
+                pipeline_status=pipeline_status,
+                pipeline_status_lock=pipeline_status_lock,
+                llm_response_cache=self.llm_response_cache,
             )
         except Exception as e:
             logger.error("Failed to extract entities and relationships")
             raise e
 
-    async def _insert_done(self) -> None:
+    async def _insert_done(
+        self, pipeline_status=None, pipeline_status_lock=None
+    ) -> None:
         tasks = [
             cast(StorageNameSpace, storage_inst).index_done_callback()
             for storage_inst in [  # type: ignore
@@ -1040,10 +1043,8 @@ class LightRAG:
         log_message = "All Insert done"
         logger.info(log_message)
 
-        # Get pipeline_status and update latest_message and history_messages
-        from lightrag.kg.shared_storage import get_namespace_data
-
-        pipeline_status = await get_namespace_data("pipeline_status")
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
 
@@ -1260,16 +1261,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "naive":
@@ -1279,16 +1271,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "mix":
@@ -1301,16 +1284,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         else:
@@ -1344,14 +1318,7 @@ class LightRAG:
             text=query,
             param=param,
             global_config=asdict(self),
-            hashing_kv=self.llm_response_cache
-            or self.key_string_value_json_storage_cls(
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            ),
+            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
         )
 
         param.hl_keywords = hl_keywords
@@ -1375,16 +1342,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1393,16 +1351,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "mix":
             response = await mix_kg_vector_query(
@@ -1414,16 +1363,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
@@ -3,6 +3,7 @@ from __future__ import annotations
 import asyncio
 import json
 import re
+import os
 from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
 
@@ -220,6 +221,7 @@ async def _merge_nodes_then_upsert(
         entity_name, description, global_config
     )
     node_data = dict(
+        entity_id=entity_name,
         entity_type=entity_type,
         description=description,
         source_id=source_id,
@@ -301,6 +303,7 @@ async def _merge_edges_then_upsert(
             await knowledge_graph_inst.upsert_node(
                 need_insert_id,
                 node_data={
+                    "entity_id": need_insert_id,
                     "source_id": source_id,
                     "description": description,
                     "entity_type": "UNKNOWN",
@@ -337,11 +340,10 @@ async def extract_entities(
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
 ) -> None:
-    from lightrag.kg.shared_storage import get_namespace_data
-
-    pipeline_status = await get_namespace_data("pipeline_status")
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -400,6 +402,7 @@ async def extract_entities(
         else:
             _prompt = input_text
 
+        # TODO: add cache_type="extract"
         arg_hash = compute_args_hash(_prompt)
         cached_return, _1, _2, _3 = await handle_cache(
             llm_response_cache,
@@ -407,7 +410,6 @@ async def extract_entities(
             _prompt,
             "default",
             cache_type="extract",
-            force_llm_cache=True,
         )
         if cached_return:
             logger.debug(f"Found cache for {arg_hash}")
@@ -504,6 +506,8 @@ async def extract_entities(
         relations_count = len(maybe_edges)
         log_message = f"  Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
         logger.info(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
         return dict(maybe_nodes), dict(maybe_edges)
@@ -519,6 +523,12 @@ async def extract_entities(
     for k, v in m_edges.items():
         maybe_edges[tuple(sorted(k))].extend(v)
 
+    from .kg.shared_storage import get_graph_db_lock
+
+    graph_db_lock = get_graph_db_lock(enable_logging=False)
+
+    # Ensure that nodes and edges are merged and upserted atomically
+    async with graph_db_lock:
         all_entities_data = await asyncio.gather(
             *[
                 _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
@@ -528,7 +538,9 @@ async def extract_entities(
 
         all_relationships_data = await asyncio.gather(
             *[
-                _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
                 for k, v in maybe_edges.items()
             ]
         )
@@ -536,6 +548,8 @@ async def extract_entities(
         if not (all_entities_data or all_relationships_data):
             log_message = "Didn't extract any entities and relationships."
             logger.info(log_message)
+            if pipeline_status is not None:
+                async with pipeline_status_lock:
                     pipeline_status["latest_message"] = log_message
                     pipeline_status["history_messages"].append(log_message)
             return
@@ -543,16 +557,22 @@ async def extract_entities(
         if not all_entities_data:
             log_message = "Didn't extract any entities"
             logger.info(log_message)
+            if pipeline_status is not None:
+                async with pipeline_status_lock:
                     pipeline_status["latest_message"] = log_message
                     pipeline_status["history_messages"].append(log_message)
         if not all_relationships_data:
             log_message = "Didn't extract any relationships"
             logger.info(log_message)
+            if pipeline_status is not None:
+                async with pipeline_status_lock:
                     pipeline_status["latest_message"] = log_message
                     pipeline_status["history_messages"].append(log_message)
 
         log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
         logger.info(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
     verbose_debug(
@@ -1017,6 +1037,7 @@ async def _build_query_context(
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
+    logger.info(f"Process {os.getpid()} buidling query context...")
     if query_param.mode == "local":
         entities_context, relations_context, text_units_context = await _get_node_data(
             ll_keywords,
|
@@ -633,15 +633,15 @@ async def handle_cache(
|
|||||||
prompt,
|
prompt,
|
||||||
mode="default",
|
mode="default",
|
||||||
cache_type=None,
|
cache_type=None,
|
||||||
force_llm_cache=False,
|
|
||||||
):
|
):
|
||||||
"""Generic cache handling function"""
|
"""Generic cache handling function"""
|
||||||
if hashing_kv is None or not (
|
if hashing_kv is None:
|
||||||
force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
|
return None, None, None, None
|
||||||
):
|
|
||||||
|
if mode != "default": # handle cache for all type of query
|
||||||
|
if not hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
if mode != "default":
|
|
||||||
# Get embedding cache configuration
|
# Get embedding cache configuration
|
||||||
embedding_cache_config = hashing_kv.global_config.get(
|
embedding_cache_config = hashing_kv.global_config.get(
|
||||||
"embedding_cache_config",
|
"embedding_cache_config",
|
||||||
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)
 
         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding simularity to match cache
             current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -667,24 +666,29 @@ async def handle_cache(
                 cache_type=cache_type,
             )
             if best_cached_response is not None:
-                logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
+                logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})")
                 return best_cached_response, None, None, None
             else:
                 # if caching keyword embedding is enabled, return the quantized embedding for saving it latter
-                logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
+                logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
                 return None, quantized, min_val, max_val
 
-    # For default mode or is_embedding_cache_enabled is False, use regular cache
-    # default mode is for extract_entities or naive query
+    else:  # handle cache for entity extraction
+        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
+            return None, None, None, None
+
+    # Here is the conditions of code reaching this point:
+    # 1. All query mode: enable_llm_cache is True and embedding simularity is not enabled
+    # 2. Entity extract: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
         mode_cache = await hashing_kv.get_by_id(mode) or {}
     if args_hash in mode_cache:
-        logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
+        logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
         return mode_cache[args_hash]["return"], None, None, None
 
-    logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
+    logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
     return None, None, None, None
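Note: a condensed sketch of the new handle_cache() gating, with names simplified and the embedding-similarity branch elided. Query modes now gate on enable_llm_cache, while the default (extraction) path gates on enable_llm_cache_for_entity_extract; anything that passes falls through to the regular key lookup:

def cache_decision(mode: str, config: dict) -> str:
    if mode != "default":              # all query modes
        if not config.get("enable_llm_cache"):
            return "no-cache"
    else:                              # entity-extraction / default path
        if not config.get("enable_llm_cache_for_entity_extract"):
            return "no-cache"
    return "regular-cache-lookup"      # get_by_mode_and_id / get_by_id

print(cache_decision("local", {"enable_llm_cache": True}))                        # regular-cache-lookup
print(cache_decision("default", {"enable_llm_cache_for_entity_extract": False}))  # no-cache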
@@ -701,9 +705,22 @@ class CacheData:
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
+    """Save data to cache, with improved handling for streaming responses and duplicate content.
+
+    Args:
+        hashing_kv: The key-value storage for caching
+        cache_data: The cache data to save
+    """
+    # Skip if storage is None or content is a streaming response
+    if hashing_kv is None or not cache_data.content:
         return
+
+    # If content is a streaming response, don't cache it
+    if hasattr(cache_data.content, "__aiter__"):
+        logger.debug("Streaming response detected, skipping cache")
+        return
+
+    # Get existing cache data
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = (
             await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
@@ -712,6 +729,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     else:
         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
 
+    # Check if we already have identical content cached
+    if cache_data.args_hash in mode_cache:
+        existing_content = mode_cache[cache_data.args_hash].get("return")
+        if existing_content == cache_data.content:
+            logger.info(
+                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
+            )
+            return
+
+    # Update cache with new content
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
         "cache_type": cache_data.cache_type,
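Note: the duplicate-content guard in miniature. Writing an identical cached answer a second time is skipped, so the KV store's update flag is never set and no needless re-persist happens. A self-contained toy version (hypothetical names):

mode_cache = {"h1": {"return": "answer A"}}

def save(args_hash: str, content: str) -> str:
    # Skip the upsert when the cached content is unchanged.
    if args_hash in mode_cache and mode_cache[args_hash].get("return") == content:
        return "skipped (unchanged)"
    mode_cache[args_hash] = {"return": content}
    return "upserted"

print(save("h1", "answer A"))  # skipped (unchanged)
print(save("h1", "answer B"))  # upserted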
@@ -726,6 +753,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "original_prompt": cache_data.prompt,
     }
 
+    # Only upsert if there's actual new content
     await hashing_kv.upsert({cache_data.mode: mode_cache})