Merge pull request #1036 from danielaskdd/neo4j-add-min-degree

Refactoring Neo4j implementation and fixing storage init problem for Gunicorn
Authored by zrguo on 2025-03-10 22:24:22 +08:00, committed by GitHub
9 changed files with 1081 additions and 545 deletions

View File

@@ -50,9 +50,6 @@ from .auth import auth_handler
 # This update allows the user to put a different.env file for each lightrag folder
 load_dotenv(".env", override=True)

-# Read entity extraction cache config
-enable_llm_cache = os.getenv("ENABLE_LLM_CACHE_FOR_EXTRACT", "false").lower() == "true"
-
 # Initialize config parser
 config = configparser.ConfigParser()
 config.read("config.ini")
@@ -144,23 +141,25 @@ def create_app(args):
         try:
             # Initialize database connections
             await rag.initialize_storages()
             await initialize_pipeline_status()
+            pipeline_status = await get_namespace_data("pipeline_status")

-            # Auto scan documents if enabled
-            if args.auto_scan_at_startup:
-                # Check if a task is already running (with lock protection)
-                pipeline_status = await get_namespace_data("pipeline_status")
-                should_start_task = False
-                async with get_pipeline_status_lock():
-                    if not pipeline_status.get("busy", False):
-                        should_start_task = True
-                # Only start the task if no other task is running
-                if should_start_task:
+            should_start_autoscan = False
+            async with get_pipeline_status_lock():
+                # Auto scan documents if enabled
+                if args.auto_scan_at_startup:
+                    if not pipeline_status.get("autoscanned", False):
+                        pipeline_status["autoscanned"] = True
+                        should_start_autoscan = True
+
+            # Only run auto scan when no other process started it first
+            if should_start_autoscan:
                 # Create background task
                 task = asyncio.create_task(run_scanning_process(rag, doc_manager))
                 app.state.background_tasks.add(task)
                 task.add_done_callback(app.state.background_tasks.discard)
-                logger.info("Auto scan task started at startup.")
+                logger.info(f"Process {os.getpid()} auto scan task started at startup.")

             ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
@@ -326,7 +325,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=enable_llm_cache,  # Read from environment variable
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -355,7 +354,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=enable_llm_cache,  # Read from environment variable
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -419,6 +418,7 @@ def create_app(args):
             "doc_status_storage": args.doc_status_storage,
             "graph_storage": args.graph_storage,
             "vector_storage": args.vector_storage,
+            "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
         },
         "update_status": update_status,
     }
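
Note: the startup change above replaces the old per-process "busy" check with a shared "autoscanned" flag, so only the first worker to take the pipeline-status lock launches the auto-scan. A minimal sketch of the pattern (standalone; a plain asyncio.Lock and dict stand in for the Manager-backed shared namespace used across Gunicorn workers):

    import asyncio

    pipeline_status: dict = {}              # stand-in for the shared namespace
    pipeline_status_lock = asyncio.Lock()   # stand-in for get_pipeline_status_lock()

    async def maybe_start_autoscan(worker_id: int) -> bool:
        """Return True only for the single worker that claims the flag."""
        should_start = False
        async with pipeline_status_lock:
            # Check-and-set happens atomically under the lock.
            if not pipeline_status.get("autoscanned", False):
                pipeline_status["autoscanned"] = True
                should_start = True
        return should_start

    async def main():
        results = await asyncio.gather(*(maybe_start_autoscan(i) for i in range(4)))
        assert sum(results) == 1  # exactly one "worker" wins

    asyncio.run(main())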

View File

@@ -362,6 +362,11 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)

+    # Inject LLM cache configuration
+    args.enable_llm_cache_for_extract = get_env_value(
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
+    )
+
     # Select Document loading tool (DOCLING, DEFAULT)
     args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
@@ -457,8 +462,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.history_turns}")
     ASCIIColors.white("    ├─ Cosine Threshold: ", end="")
     ASCIIColors.yellow(f"{args.cosine_threshold}")
-    ASCIIColors.white("    └─ Top-K: ", end="")
+    ASCIIColors.white("    ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
+    ASCIIColors.white("    └─ LLM Cache for Extraction Enabled: ", end="")
+    ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")

     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")
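
Note: get_env_value reads an environment variable and coerces it to the requested type. A stand-in sketch of the assumed behavior (not the library's actual implementation), showing why the bool case needs special handling — bool("false") is True in Python:

    import os
    from typing import Callable, TypeVar

    T = TypeVar("T")

    def get_env_value(name: str, default: T, cast: Callable = str) -> T:
        raw = os.getenv(name)
        if raw is None:
            return default
        if cast is bool:
            # Parse the usual truthy strings instead of calling bool()
            return raw.strip().lower() in ("1", "true", "yes", "on")
        return cast(raw)

    # ENABLE_LLM_CACHE_FOR_EXTRACT=true -> True; unset -> False
    print(get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool))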

View File

@@ -15,6 +15,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
@@ -27,20 +31,24 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None

     async def initialize(self):
         """Initialize storage data"""
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
             # check need_init must before get_namespace_data
-            need_init = try_initialize_namespace(self.namespace)
+            need_init = await try_initialize_namespace(self.namespace)
             self._data = await get_namespace_data(self.namespace)
             if need_init:
                 loaded_data = load_json(self._file_name) or {}
                 async with self._storage_lock:
                     self._data.update(loaded_data)
                 logger.info(
-                    f"Loaded document status storage with {len(loaded_data)} records"
+                    f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
                 )

     async def filter_keys(self, keys: set[str]) -> set[str]:
@@ -87,18 +95,24 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
+            if self.storage_updated.value:
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
+                logger.info(
+                    f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
@@ -109,9 +123,12 @@ class JsonDocStatusStorage(DocStatusStorage):
         async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()

     async def drop(self) -> None:
         """Drop the storage"""
         async with self._storage_lock:
             self._data.clear()
+            await set_all_update_flags(self.namespace)
+        await self.index_done_callback()
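
Note: the pattern introduced in both JSON storages is write-only-when-dirty. Every mutation raises an update flag for all workers, and index_done_callback persists to disk only when this worker's flag is set, clearing the flags afterwards. A condensed single-process sketch (illustrative names; the real flags are Manager values shared across processes):

    import asyncio

    class ToyStorage:
        def __init__(self):
            self.data: dict = {}
            self.dirty = False          # stand-in for storage_updated.value
            self.lock = asyncio.Lock()  # stand-in for get_storage_lock()

        async def upsert(self, items: dict):
            async with self.lock:
                self.data.update(items)
                self.dirty = True       # like set_all_update_flags(namespace)
            await self.index_done_callback()

        async def index_done_callback(self):
            async with self.lock:
                if self.dirty:          # skip the write if nothing changed
                    print(f"writing {len(self.data)} records")
                    self.dirty = False  # like clear_all_update_flags(namespace)

    async def main():
        s = ToyStorage()
        await s.upsert({"doc1": {"status": "done"}})
        await s.index_done_callback()   # no-op: flag already cleared

    asyncio.run(main())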

View File

@@ -13,6 +13,10 @@ from lightrag.utils import (
 from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
+    get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
@@ -23,26 +27,63 @@ class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._storage_lock = get_storage_lock()
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None

     async def initialize(self):
         """Initialize storage data"""
+        self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
+        async with get_data_init_lock():
             # check need_init must before get_namespace_data
-            need_init = try_initialize_namespace(self.namespace)
+            need_init = await try_initialize_namespace(self.namespace)
             self._data = await get_namespace_data(self.namespace)
             if need_init:
                 loaded_data = load_json(self._file_name) or {}
                 async with self._storage_lock:
                     self._data.update(loaded_data)
-                logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
+
+                # Calculate data count based on namespace
+                if self.namespace.endswith("cache"):
+                    # For cache namespaces, sum the cache entries across all cache types
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in loaded_data.values()
+                        if isinstance(first_level_dict, dict)
+                    )
+                else:
+                    # For non-cache namespaces, use the original count method
+                    data_count = len(loaded_data)
+
+                logger.info(
+                    f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
+                )

     async def index_done_callback(self) -> None:
         async with self._storage_lock:
+            if self.storage_updated.value:
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
+
+                # Calculate data count based on namespace
+                if self.namespace.endswith("cache"):
+                    # For cache namespaces, sum the cache entries across all cache types
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in data_dict.values()
+                        if isinstance(first_level_dict, dict)
+                    )
+                else:
+                    # For non-cache namespaces, use the original count method
+                    data_count = len(data_dict)
+
+                logger.info(
+                    f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)

     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage
@@ -73,15 +114,16 @@ class JsonKVStorage(BaseKVStorage):
         return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
-            left_data = {k: v for k, v in data.items() if k not in self._data}
-            self._data.update(left_data)
+            self._data.update(data)
+            await set_all_update_flags(self.namespace)

     async def delete(self, ids: list[str]) -> None:
         async with self._storage_lock:
             for doc_id in ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()
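
Note: one behavioral change worth calling out — upsert was previously insert-only (keys already present were skipped via left_data); it now overwrites existing keys, which the deduplicating save_to_cache further below relies on when refreshing an entry. In plain dict terms:

    data = {"k1": {"v": 1}, "k2": {"v": 2}}

    # Old behavior: existing keys kept their stale values
    store = {"k1": {"v": 0}}
    left_data = {k: v for k, v in data.items() if k not in store}
    store.update(left_data)   # k1 stays {"v": 0}

    # New behavior: upsert means insert-or-update
    store = {"k1": {"v": 0}}
    store.update(data)        # k1 refreshed to {"v": 1}
    print(store)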

File diff suppressed because it is too large

View File

@@ -7,11 +7,17 @@ from typing import Any, Dict, Optional, Union, TypeVar, Generic

 # Define a direct print function for critical logs that must be visible in all processes
-def direct_log(message, level="INFO"):
+def direct_log(message, level="INFO", enable_output: bool = True):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.
+
+    Args:
+        message: The message to log
+        level: Log level (default: "INFO")
+        enable_output: Whether to actually output the log (default: True)
     """
-    print(f"{level}: {message}", file=sys.stderr, flush=True)
+    if enable_output:
+        print(f"{level}: {message}", file=sys.stderr, flush=True)
@@ -32,55 +38,165 @@ _update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated
 _storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
+_graph_db_lock: Optional[LockType] = None
+_data_init_lock: Optional[LockType] = None


 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""

-    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
+    def __init__(
+        self,
+        lock: Union[ProcessLock, asyncio.Lock],
+        is_async: bool,
+        name: str = "unnamed",
+        enable_logging: bool = True,
+    ):
         self._lock = lock
         self._is_async = is_async
+        self._pid = os.getpid()  # for debug only
+        self._name = name  # for debug only
+        self._enable_logging = enable_logging  # for debug only

     async def __aenter__(self) -> "UnifiedLock[T]":
+        try:
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 await self._lock.acquire()
             else:
                 self._lock.acquire()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             return self
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise

     async def __aexit__(self, exc_type, exc_val, exc_tb):
+        try:
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 self._lock.release()
             else:
                 self._lock.release()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise

     def __enter__(self) -> "UnifiedLock[T]":
         """For backward compatibility"""
+        try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.acquire()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                enable_output=self._enable_logging,
+            )
             return self
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise

     def __exit__(self, exc_type, exc_val, exc_tb):
         """For backward compatibility"""
+        try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.release()
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                enable_output=self._enable_logging,
+            )
+        except Exception as e:
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
+            raise


-def get_internal_lock() -> UnifiedLock:
+def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_internal_lock,
+        is_async=not is_multiprocess,
+        name="internal_lock",
+        enable_logging=enable_logging,
+    )


-def get_storage_lock() -> UnifiedLock:
+def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_storage_lock,
+        is_async=not is_multiprocess,
+        name="storage_lock",
+        enable_logging=enable_logging,
+    )


-def get_pipeline_status_lock() -> UnifiedLock:
+def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
+    return UnifiedLock(
+        lock=_pipeline_status_lock,
+        is_async=not is_multiprocess,
+        name="pipeline_status_lock",
+        enable_logging=enable_logging,
+    )


+def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
+    """return unified graph database lock for ensuring atomic operations"""
+    return UnifiedLock(
+        lock=_graph_db_lock,
+        is_async=not is_multiprocess,
+        name="graph_db_lock",
+        enable_logging=enable_logging,
+    )


+def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
+    """return unified data initialization lock for ensuring atomic data initialization"""
+    return UnifiedLock(
+        lock=_data_init_lock,
+        is_async=not is_multiprocess,
+        name="data_init_lock",
+        enable_logging=enable_logging,
+    )
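
Note: UnifiedLock hides the single-worker/multi-worker split behind one async context manager, so storage code never branches on the mode. A self-contained sketch of the core idea with the debug logging stripped (the real class also keeps sync __enter__/__exit__ for backward compatibility):

    import asyncio
    from multiprocessing import Manager

    class UnifiedLock:
        """Async context manager over an asyncio lock or a process lock."""
        def __init__(self, lock, is_async: bool):
            self._lock, self._is_async = lock, is_async

        async def __aenter__(self):
            if self._is_async:
                await self._lock.acquire()  # cooperative asyncio wait
            else:
                self._lock.acquire()        # blocking Manager lock
            return self

        async def __aexit__(self, *exc):
            self._lock.release()

    async def demo():
        for lock in (UnifiedLock(asyncio.Lock(), is_async=True),
                     UnifiedLock(Manager().Lock(), is_async=False)):
            async with lock:
                print("critical section under", type(lock._lock).__name__)

    if __name__ == "__main__":
        asyncio.run(demo())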
 def initialize_share_data(workers: int = 1):
@@ -108,6 +224,8 @@ def initialize_share_data(workers: int = 1):
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
+        _graph_db_lock, \
+        _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
@@ -120,14 +238,16 @@
         )
         return

-    _manager = Manager()
     _workers = workers

     if workers > 1:
         is_multiprocess = True
+        _manager = Manager()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()
+        _graph_db_lock = _manager.Lock()
+        _data_init_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()
@@ -139,6 +259,8 @@
         _internal_lock = asyncio.Lock()
         _storage_lock = asyncio.Lock()
         _pipeline_status_lock = asyncio.Lock()
+        _graph_db_lock = asyncio.Lock()
+        _data_init_lock = asyncio.Lock()
         _shared_dicts = {}
         _init_flags = {}
         _update_flags = {}
@@ -164,6 +286,7 @@ async def initialize_pipeline_status():
     history_messages = _manager.list() if is_multiprocess else []
     pipeline_namespace.update(
         {
+            "autoscanned": False,  # Auto-scan started
             "busy": False,  # Control concurrent processes
             "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
             "job_start": None,  # Job start time
@@ -200,7 +323,12 @@ async def get_update_flag(namespace: str):
         if is_multiprocess and _manager is not None:
             new_update_flag = _manager.Value("b", False)
         else:
-            new_update_flag = False
+            # Create a simple mutable object to store boolean value for compatibility with mutiprocess
+            class MutableBoolean:
+                def __init__(self, initial_value=False):
+                    self.value = initial_value
+
+            new_update_flag = MutableBoolean(False)

         _update_flags[namespace].append(new_update_flag)
         return new_update_flag
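
Note: the MutableBoolean wrapper exists so that single-process mode exposes the same .value interface as a Manager().Value("b", ...) proxy; a bare bool stored in a list cannot be flipped in place through a reference held by a storage instance. Illustrative sketch:

    from multiprocessing import Manager

    class MutableBoolean:
        """Single-process stand-in with the same .value API as Manager().Value."""
        def __init__(self, initial_value: bool = False):
            self.value = initial_value

    def make_flag(multiprocess: bool):
        # Either way the caller gets an object whose .value mutates in place,
        # so every holder of the reference observes the change.
        return Manager().Value("b", False) if multiprocess else MutableBoolean(False)

    flag = make_flag(multiprocess=False)
    held_by_storage = flag        # e.g. self.storage_updated
    flag.value = True
    print(held_by_storage.value)  # True — a plain bool in a list couldn't do this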
@@ -220,7 +348,26 @@ async def set_all_update_flags(namespace: str):
             if is_multiprocess:
                 _update_flags[namespace][i].value = True
             else:
-                _update_flags[namespace][i] = True
+                # Use .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = True
+
+
+async def clear_all_update_flags(namespace: str):
+    """Clear all update flag of namespace indicating all workers need to reload data from files"""
+    global _update_flags
+    if _update_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    async with get_internal_lock():
+        if namespace not in _update_flags:
+            raise ValueError(f"Namespace {namespace} not found in update flags")
+        # Update flags for both modes
+        for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
+                _update_flags[namespace][i].value = False
+            else:
+                # Use .value attribute instead of direct assignment
+                _update_flags[namespace][i].value = False
 async def get_all_update_flags_status() -> Dict[str, list]:
@@ -247,7 +394,7 @@
     return result


-def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(namespace: str) -> bool:
     """
     Returns True if the current worker(process) gets initialization permission for loading data later.
     The worker does not get the permission is prohibited to load data from files.
@@ -257,6 +404,7 @@ def try_initialize_namespace(namespace: str) -> bool:
     if _init_flags is None:
         raise ValueError("Try to create nanmespace before Shared-Data is initialized")

+    async with get_internal_lock():
         if namespace not in _init_flags:
             _init_flags[namespace] = True
             direct_log(
@@ -266,6 +414,7 @@ def try_initialize_namespace(namespace: str) -> bool:
         direct_log(
             f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
         )
+
         return False
@@ -304,6 +453,8 @@ def finalize_share_data():
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
+        _graph_db_lock, \
+        _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
@@ -369,6 +520,8 @@
         _storage_lock = None
         _internal_lock = None
         _pipeline_status_lock = None
+        _graph_db_lock = None
+        _data_init_lock = None
         _update_flags = None

     direct_log(f"Process {os.getpid()} storage data finalization complete")

View File

@@ -354,6 +354,9 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
+            global_config=asdict(
+                self
+            ),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )
@@ -404,18 +407,8 @@ class LightRAG:
             embedding_func=None,
         )

-        if self.llm_response_cache and hasattr(
-            self.llm_response_cache, "global_config"
-        ):
-            hashing_kv = self.llm_response_cache
-        else:
-            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
+        # Directly use llm_response_cache, don't create a new object
+        hashing_kv = self.llm_response_cache

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
@@ -590,6 +583,7 @@ class LightRAG:
             split_by_character, split_by_character_only
         )

+    # TODO: deprecated, use insert instead
     def insert_custom_chunks(
         self,
         full_text: str,
@@ -601,6 +595,7 @@ class LightRAG:
             self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
         )

+    # TODO: deprecated, use ainsert instead
     async def ainsert_custom_chunks(
         self, full_text: str, text_chunks: list[str], doc_id: str | None = None
     ) -> None:
@@ -892,7 +887,9 @@ class LightRAG:
                     self.chunks_vdb.upsert(chunks)
                 )
                 entity_relation_task = asyncio.create_task(
-                    self._process_entity_relation_graph(chunks)
+                    self._process_entity_relation_graph(
+                        chunks, pipeline_status, pipeline_status_lock
+                    )
                 )
                 full_docs_task = asyncio.create_task(
                     self.full_docs.upsert(
@@ -1007,21 +1004,27 @@ class LightRAG:
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)

-    async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
+    async def _process_entity_relation_graph(
+        self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
+    ) -> None:
         try:
             await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
                 relationships_vdb=self.relationships_vdb,
-                llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
+                pipeline_status=pipeline_status,
+                pipeline_status_lock=pipeline_status_lock,
+                llm_response_cache=self.llm_response_cache,
             )
         except Exception as e:
             logger.error("Failed to extract entities and relationships")
             raise e

-    async def _insert_done(self) -> None:
+    async def _insert_done(
+        self, pipeline_status=None, pipeline_status_lock=None
+    ) -> None:
         tasks = [
             cast(StorageNameSpace, storage_inst).index_done_callback()
             for storage_inst in [  # type: ignore
@@ -1040,10 +1043,8 @@ class LightRAG:
         log_message = "All Insert done"
         logger.info(log_message)

-        # Get pipeline_status and update latest_message and history_messages
-        from lightrag.kg.shared_storage import get_namespace_data
-        pipeline_status = await get_namespace_data("pipeline_status")
-        pipeline_status["latest_message"] = log_message
-        pipeline_status["history_messages"].append(log_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
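
Note: the recurring refactor in this file is dependency injection of the pipeline status. Instead of each helper importing shared_storage and fetching the namespace itself, the caller resolves pipeline_status and its lock once and passes both down the call chain; helpers degrade to plain logging when run outside the pipeline. The shape of the pattern:

    import asyncio

    async def report(log_message: str, pipeline_status=None, pipeline_status_lock=None):
        # Shared state arrives as arguments; None means "no pipeline attached"
        if pipeline_status is not None and pipeline_status_lock is not None:
            async with pipeline_status_lock:
                pipeline_status["latest_message"] = log_message
                pipeline_status["history_messages"].append(log_message)

    async def main():
        status = {"latest_message": "", "history_messages": []}
        lock = asyncio.Lock()
        await report("All Insert done", status, lock)
        print(status["latest_message"])

    asyncio.run(main())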
@@ -1260,16 +1261,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "naive":
@@ -1279,16 +1271,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "mix":
@@ -1301,16 +1284,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         else:
@@ -1344,14 +1318,7 @@ class LightRAG:
             text=query,
             param=param,
             global_config=asdict(self),
-            hashing_kv=self.llm_response_cache
-            or self.key_string_value_json_storage_cls(
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            ),
+            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
         )

         param.hl_keywords = hl_keywords
@@ -1375,16 +1342,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1393,16 +1351,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "mix":
             response = await mix_kg_vector_query(
@@ -1414,16 +1363,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         else:
             raise ValueError(f"Unknown mode {param.mode}")

View File

@@ -3,6 +3,7 @@ from __future__ import annotations

 import asyncio
 import json
 import re
+import os
 from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
@@ -220,6 +221,7 @@ async def _merge_nodes_then_upsert(
         entity_name, description, global_config
     )
     node_data = dict(
+        entity_id=entity_name,
         entity_type=entity_type,
         description=description,
         source_id=source_id,
@@ -301,6 +303,7 @@ async def _merge_edges_then_upsert(
             await knowledge_graph_inst.upsert_node(
                 need_insert_id,
                 node_data={
+                    "entity_id": need_insert_id,
                     "source_id": source_id,
                     "description": description,
                     "entity_type": "UNKNOWN",
@@ -337,11 +340,10 @@ async def extract_entities(
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
 ) -> None:
-    from lightrag.kg.shared_storage import get_namespace_data
-
-    pipeline_status = await get_namespace_data("pipeline_status")
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -400,6 +402,7 @@ async def extract_entities(
         else:
             _prompt = input_text

+        # TODO add cache_type="extract"
         arg_hash = compute_args_hash(_prompt)
         cached_return, _1, _2, _3 = await handle_cache(
             llm_response_cache,
@@ -407,7 +410,6 @@ async def extract_entities(
             _prompt,
             "default",
             cache_type="extract",
-            force_llm_cache=True,
         )
         if cached_return:
             logger.debug(f"Found cache for {arg_hash}")
@@ -504,6 +506,8 @@ async def extract_entities(
         relations_count = len(maybe_edges)
         log_message = f"  Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
         logger.info(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)

         return dict(maybe_nodes), dict(maybe_edges)
@@ -519,6 +523,12 @@ async def extract_entities(
     for k, v in m_edges.items():
         maybe_edges[tuple(sorted(k))].extend(v)

+    from .kg.shared_storage import get_graph_db_lock
+
+    graph_db_lock = get_graph_db_lock(enable_logging=False)
+
+    # Ensure that nodes and edges are merged and upserted atomically
+    async with graph_db_lock:
         all_entities_data = await asyncio.gather(
             *[
                 _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
@@ -528,7 +538,9 @@ async def extract_entities(
         all_relationships_data = await asyncio.gather(
             *[
-                _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
                 for k, v in maybe_edges.items()
             ]
         )
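
Note: graph_db_lock serializes the whole merge-and-upsert phase, so two concurrent extraction jobs cannot interleave their read-modify-write cycles on the same nodes. A toy illustration of the lost-update hazard the lock prevents:

    import asyncio

    graph: dict[str, int] = {}
    graph_db_lock = asyncio.Lock()

    async def merge_then_upsert(entity: str):
        current = graph.get(entity, 0)
        await asyncio.sleep(0)        # yield point where jobs could interleave
        graph[entity] = current + 1   # read-modify-write

    async def extract_job():
        async with graph_db_lock:     # whole merge phase is atomic
            await asyncio.gather(*(merge_then_upsert(e) for e in ["A", "B"]))

    async def main():
        await asyncio.gather(extract_job(), extract_job())
        print(graph)                  # {'A': 2, 'B': 2} — no lost updates

    asyncio.run(main())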
@@ -536,6 +548,8 @@ async def extract_entities(
         if not (all_entities_data or all_relationships_data):
             log_message = "Didn't extract any entities and relationships."
             logger.info(log_message)
+            if pipeline_status is not None:
+                async with pipeline_status_lock:
                     pipeline_status["latest_message"] = log_message
                     pipeline_status["history_messages"].append(log_message)
             return
@@ -543,16 +557,22 @@ async def extract_entities(
         if not all_entities_data:
             log_message = "Didn't extract any entities"
             logger.info(log_message)
+            if pipeline_status is not None:
+                async with pipeline_status_lock:
                     pipeline_status["latest_message"] = log_message
                     pipeline_status["history_messages"].append(log_message)
         if not all_relationships_data:
             log_message = "Didn't extract any relationships"
             logger.info(log_message)
+            if pipeline_status is not None:
+                async with pipeline_status_lock:
                     pipeline_status["latest_message"] = log_message
                     pipeline_status["history_messages"].append(log_message)

     log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
     logger.info(log_message)
+    if pipeline_status is not None:
+        async with pipeline_status_lock:
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)

     verbose_debug(
@@ -1017,6 +1037,7 @@ async def _build_query_context(
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
+    logger.info(f"Process {os.getpid()} buidling query context...")
     if query_param.mode == "local":
         entities_context, relations_context, text_units_context = await _get_node_data(
             ll_keywords,

View File

@@ -633,15 +633,15 @@ async def handle_cache(
     prompt,
     mode="default",
     cache_type=None,
-    force_llm_cache=False,
 ):
     """Generic cache handling function"""
-    if hashing_kv is None or not (
-        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
-    ):
+    if hashing_kv is None:
+        return None, None, None, None
+
+    if mode != "default":  # handle cache for all type of query
+        if not hashing_kv.global_config.get("enable_llm_cache"):
             return None, None, None, None

-    if mode != "default":
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)

         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding simularity to match cache
             current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -667,24 +666,29 @@ async def handle_cache(
                 cache_type=cache_type,
             )
             if best_cached_response is not None:
-                logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
+                logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})")
                 return best_cached_response, None, None, None
             else:
                 # if caching keyword embedding is enabled, return the quantized embedding for saving it latter
-                logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
+                logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
                 return None, quantized, min_val, max_val

-    # For default mode or is_embedding_cache_enabled is False, use regular cache
-    # default mode is for extract_entities or naive query
+    else:  # handle cache for entity extraction
+        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
+            return None, None, None, None
+
+    # Here is the conditions of code reaching this point:
+    # 1. All query mode: enable_llm_cache is True and embedding simularity is not enabled
+    # 2. Entity extract: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
         mode_cache = await hashing_kv.get_by_id(mode) or {}
+
     if args_hash in mode_cache:
-        logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
+        logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
         return mode_cache[args_hash]["return"], None, None, None

-    logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
+    logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
     return None, None, None, None
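
Note: after this rewrite handle_cache decides cache eligibility per call path — query modes (mode != "default") check enable_llm_cache and may short-circuit through embedding-similarity matching, while the default mode used by entity extraction checks enable_llm_cache_for_entity_extract, replacing the removed force_llm_cache override. The control flow, distilled (illustrative config dict, not the real storage object):

    def cache_policy(mode: str, config: dict) -> str:
        if mode != "default":                      # all query modes
            if not config.get("enable_llm_cache"):
                return "no cache"
            if config.get("embedding_cache_config", {}).get("enabled"):
                return "embedding-similarity lookup"
        else:                                      # entity-extraction path
            if not config.get("enable_llm_cache_for_entity_extract"):
                return "no cache"
        return "exact-hash lookup"

    print(cache_policy("local", {"enable_llm_cache": True}))                       # exact-hash lookup
    print(cache_policy("default", {"enable_llm_cache_for_entity_extract": True}))  # exact-hash lookup
    print(cache_policy("default", {}))                                             # no cache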
@@ -701,9 +705,22 @@ class CacheData:
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
+    """Save data to cache, with improved handling for streaming responses and duplicate content.
+
+    Args:
+        hashing_kv: The key-value storage for caching
+        cache_data: The cache data to save
+    """
+    # Skip if storage is None or content is a streaming response
+    if hashing_kv is None or not cache_data.content:
         return
+
+    # If content is a streaming response, don't cache it
+    if hasattr(cache_data.content, "__aiter__"):
+        logger.debug("Streaming response detected, skipping cache")
+        return
+
+    # Get existing cache data
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = (
             await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
@@ -712,6 +729,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     else:
         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

+    # Check if we already have identical content cached
+    if cache_data.args_hash in mode_cache:
+        existing_content = mode_cache[cache_data.args_hash].get("return")
+        if existing_content == cache_data.content:
+            logger.info(
+                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
+            )
+            return
+
+    # Update cache with new content
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
         "cache_type": cache_data.cache_type,
@@ -726,6 +753,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "original_prompt": cache_data.prompt,
     }

+    # Only upsert if there's actual new content
     await hashing_kv.upsert({cache_data.mode: mode_cache})