diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 0b660d68..d0841eec 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -181,10 +181,10 @@ class Neo4JStorage(BaseGraphStorage):
 
         Args:
             label: The label to validate
-
+
         Returns:
             str: The cleaned label
-
+
         Raises:
             ValueError: If label is empty after cleaning
         """
@@ -283,7 +283,9 @@ class Neo4JStorage(BaseGraphStorage):
                 query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
                 result = await session.run(query, entity_id=entity_name_label)
                 try:
-                    records = await result.fetch(2)  # Get 2 records for duplication check
+                    records = await result.fetch(
+                        2
+                    )  # Get 2 records for duplication check
 
                     if len(records) > 1:
                         logger.warning(
@@ -552,6 +554,7 @@ class Neo4JStorage(BaseGraphStorage):
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
+
                 async def execute_upsert(tx: AsyncManagedTransaction):
                     query = f"""
                     MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
@@ -562,7 +565,7 @@ class Neo4JStorage(BaseGraphStorage):
                         f"Upserted node with label '{label}' and properties: {properties}"
                     )
                     await result.consume()  # Ensure result is fully consumed
-
+
                 await session.execute_write(execute_upsert)
         except Exception as e:
             logger.error(f"Error during upsert: {str(e)}")
@@ -602,18 +605,26 @@ class Neo4JStorage(BaseGraphStorage):
             """
             result = await session.run(query)
             try:
-                records = await result.fetch(2)  # We only need to know if there are 0, 1, or >1 nodes
-
+                records = await result.fetch(
+                    2
+                )  # We only need to know if there are 0, 1, or >1 nodes
+
                 if not records or records[0]["node_count"] == 0:
-                    raise ValueError(f"Neo4j: node with label '{node_label}' does not exist")
-
+                    raise ValueError(
+                        f"Neo4j: node with label '{node_label}' does not exist"
+                    )
+
                 if records[0]["node_count"] > 1:
-                    raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node")
-
+                    raise ValueError(
+                        f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
+                    )
+
                 node = records[0]["n"]
                 if "entity_id" not in node:
-                    raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property")
-
+                    raise ValueError(
+                        f"Neo4j: node with label '{node_label}' does not have an entity_id property"
+                    )
+
                 return node["entity_id"]
             finally:
                 await result.consume()  # Ensure result is fully consumed
@@ -656,6 +667,7 @@ class Neo4JStorage(BaseGraphStorage):
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
+
                 async def execute_upsert(tx: AsyncManagedTransaction):
                     query = f"""
                     MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
@@ -666,10 +678,10 @@ class Neo4JStorage(BaseGraphStorage):
                     RETURN r, source, target
                     """
                     result = await tx.run(
-                        query, 
+                        query,
                         source_entity_id=source_entity_id,
                         target_entity_id=target_entity_id,
-                        properties=edge_properties
+                        properties=edge_properties,
                     )
                     try:
                         records = await result.fetch(100)
@@ -681,7 +693,7 @@ class Neo4JStorage(BaseGraphStorage):
                         )
                     finally:
                         await result.consume()  # Ensure result is consumed
-
+
                 await session.execute_write(execute_upsert)
         except Exception as e:
             logger.error(f"Error during edge upsert: {str(e)}")
@@ -891,7 +903,9 @@ class Neo4JStorage(BaseGraphStorage):
                 results = await session.run(query, {"node_id": node.id})
 
                 # Get all records and release database connection
-                records = await results.fetch(1000)  # Max neighbour nodes we can handled
+                records = await results.fetch(
+                    1000
+                )  # Max neighbour nodes we can handle
                 await results.consume()  # Ensure results are consumed
 
                 # Nodes not connected to start node need to check degree
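The fetch-then-consume changes above all follow the same discipline with the async Neo4j driver: fetch a bounded number of records, then consume the result so the connection returns to the pool even on error. A minimal sketch of that pattern, assuming the official neo4j Python driver; the URI, credentials, and query are placeholders, not LightRAG's configuration:

from neo4j import AsyncGraphDatabase

async def fetch_unique_node(entity_id: str):
    # Placeholder connection details for illustration only
    driver = AsyncGraphDatabase.driver(
        "neo4j://localhost:7687", auth=("neo4j", "password")
    )
    try:
        async with driver.session() as session:
            result = await session.run(
                "MATCH (n {entity_id: $entity_id}) RETURN n", entity_id=entity_id
            )
            try:
                records = await result.fetch(2)  # two records suffice to detect duplicates
                if len(records) > 1:
                    raise ValueError(f"multiple nodes share entity_id {entity_id!r}")
                return records[0]["n"] if records else None
            finally:
                await result.consume()  # release the connection even on error
    finally:
        await driver.close()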
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 67206971..9ccb2a99 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -11,7 +11,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.
-
+
     Args:
         message: The message to log
         level: Log level (default: "INFO")
@@ -44,7 +44,13 @@ _graph_db_lock: Optional[LockType] = None
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
 
-    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool, name: str = "unnamed", enable_logging: bool = True):
+    def __init__(
+        self,
+        lock: Union[ProcessLock, asyncio.Lock],
+        is_async: bool,
+        name: str = "unnamed",
+        enable_logging: bool = True,
+    ):
         self._lock = lock
         self._is_async = is_async
         self._pid = os.getpid()  # for debug only
@@ -53,27 +59,47 @@ class UnifiedLock(Generic[T]):
 
     async def __aenter__(self) -> "UnifiedLock[T]":
         try:
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 await self._lock.acquire()
             else:
                 self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             return self
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         try:
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 self._lock.release()
             else:
                 self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
     def __enter__(self) -> "UnifiedLock[T]":
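The __aenter__/__aexit__ reformatting above makes the control flow easier to see: an asyncio lock must be awaited, a multiprocessing lock is acquired with a blocking call, and release() is synchronous for both kinds (which is why the two release branches are identical). A stripped-down sketch of that dispatch, with illustrative names rather than the module's API:

import asyncio

class MiniUnifiedLock:
    """Await asyncio locks; call blocking acquire() on multiprocessing locks."""

    def __init__(self, lock, is_async: bool):
        self._lock = lock
        self._is_async = is_async

    async def __aenter__(self):
        if self._is_async:
            await self._lock.acquire()  # cooperative: yields to the event loop
        else:
            self._lock.acquire()  # blocking: cross-process mutex
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # release() is synchronous for asyncio and multiprocessing locks alike
        self._lock.release()

async def main():
    async with MiniUnifiedLock(asyncio.Lock(), is_async=True):
        ...  # critical section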
@@ -81,12 +107,22 @@ class UnifiedLock(Generic[T]):
         try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                enable_output=self._enable_logging,
+            )
             return self
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -94,32 +130,62 @@ class UnifiedLock(Generic[T]):
         try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                enable_output=self._enable_logging,
+            )
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
 
 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_internal_lock,
+        is_async=not is_multiprocess,
+        name="internal_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_storage_lock,
+        is_async=not is_multiprocess,
+        name="storage_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_pipeline_status_lock,
+        is_async=not is_multiprocess,
+        name="pipeline_status_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified graph database lock for ensuring atomic operations"""
-    return UnifiedLock(lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_graph_db_lock,
+        is_async=not is_multiprocess,
+        name="graph_db_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def initialize_share_data(workers: int = 1):
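Each factory above closes over a module-level lock and the is_multiprocess flag, so callers get the right lock kind without caring how the server was started. A usage sketch, assuming initialize_share_data() has already set up the module's globals (with a single worker the flag presumably stays false, so plain asyncio locks are handed out):

from lightrag.kg.shared_storage import initialize_share_data, get_graph_db_lock

initialize_share_data(workers=1)  # single-process mode

async def guarded_graph_write():
    # enable_logging=True emits the acquire/release lines via direct_log
    async with get_graph_db_lock(enable_logging=True):
        ...  # graph reads and writes happen atomically here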
diff --git a/lightrag/operate.py b/lightrag/operate.py
index fb7b27a0..6c1bfd05 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -522,8 +522,9 @@ async def extract_entities(
             maybe_edges[tuple(sorted(k))].extend(v)
 
     from .kg.shared_storage import get_graph_db_lock
-    graph_db_lock = get_graph_db_lock(enable_logging = True)
-
+
+    graph_db_lock = get_graph_db_lock(enable_logging=True)
+
     # Ensure that nodes and edges are merged and upserted atomically
     async with graph_db_lock:
         all_entities_data = await asyncio.gather(
@@ -535,7 +536,9 @@ async def extract_entities(
 
         all_relationships_data = await asyncio.gather(
             *[
-                _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
                 for k, v in maybe_edges.items()
             ]
         )
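The operate.py hunks wrap the node and edge merges in the graph-database lock so concurrent extraction pipelines cannot interleave their upserts, while asyncio.gather still overlaps the per-item I/O inside the lock. A minimal sketch of that shape; merge_one and items are hypothetical stand-ins for _merge_nodes_then_upsert / _merge_edges_then_upsert and the collected entities:

import asyncio
from lightrag.kg.shared_storage import get_graph_db_lock

async def merge_one(item):
    """Hypothetical stand-in for a _merge_*_then_upsert coroutine."""
    await asyncio.sleep(0)  # placeholder for real graph I/O
    return item

async def upsert_all(items):
    graph_db_lock = get_graph_db_lock(enable_logging=True)
    # Writers queue on the lock; the merges inside it still run concurrently
    async with graph_db_lock:
        return await asyncio.gather(*(merge_one(i) for i in items))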