Fix linting

yangdx
2025-03-09 01:00:42 +08:00
parent 3cf4268e7a
commit c5d0962872
3 changed files with 120 additions and 37 deletions

View File

@@ -181,10 +181,10 @@ class Neo4JStorage(BaseGraphStorage):
         Args:
             label: The label to validate

         Returns:
             str: The cleaned label

         Raises:
             ValueError: If label is empty after cleaning
         """
@@ -283,7 +283,9 @@ class Neo4JStorage(BaseGraphStorage):
query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n" query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
result = await session.run(query, entity_id=entity_name_label) result = await session.run(query, entity_id=entity_name_label)
try: try:
records = await result.fetch(2) # Get 2 records for duplication check records = await result.fetch(
2
) # Get 2 records for duplication check
if len(records) > 1: if len(records) > 1:
logger.warning( logger.warning(
@@ -552,6 +554,7 @@ class Neo4JStorage(BaseGraphStorage):
         try:
             async with self._driver.session(database=self._DATABASE) as session:
+
                 async def execute_upsert(tx: AsyncManagedTransaction):
                     query = f"""
                     MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
@@ -562,7 +565,7 @@ class Neo4JStorage(BaseGraphStorage):
f"Upserted node with label '{label}' and properties: {properties}" f"Upserted node with label '{label}' and properties: {properties}"
) )
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert) await session.execute_write(execute_upsert)
except Exception as e: except Exception as e:
logger.error(f"Error during upsert: {str(e)}") logger.error(f"Error during upsert: {str(e)}")
@@ -602,18 +605,26 @@ class Neo4JStorage(BaseGraphStorage):
""" """
result = await session.run(query) result = await session.run(query)
try: try:
records = await result.fetch(2) # We only need to know if there are 0, 1, or >1 nodes records = await result.fetch(
2
) # We only need to know if there are 0, 1, or >1 nodes
if not records or records[0]["node_count"] == 0: if not records or records[0]["node_count"] == 0:
raise ValueError(f"Neo4j: node with label '{node_label}' does not exist") raise ValueError(
f"Neo4j: node with label '{node_label}' does not exist"
)
if records[0]["node_count"] > 1: if records[0]["node_count"] > 1:
raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node") raise ValueError(
f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
)
node = records[0]["n"] node = records[0]["n"]
if "entity_id" not in node: if "entity_id" not in node:
raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property") raise ValueError(
f"Neo4j: node with label '{node_label}' does not have an entity_id property"
)
return node["entity_id"] return node["entity_id"]
finally: finally:
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
@@ -656,6 +667,7 @@ class Neo4JStorage(BaseGraphStorage):
         try:
             async with self._driver.session(database=self._DATABASE) as session:
+
                 async def execute_upsert(tx: AsyncManagedTransaction):
                     query = f"""
                     MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
@@ -666,10 +678,10 @@ class Neo4JStorage(BaseGraphStorage):
                     RETURN r, source, target
                     """
                     result = await tx.run(
                         query,
                         source_entity_id=source_entity_id,
                         target_entity_id=target_entity_id,
-                        properties=edge_properties
+                        properties=edge_properties,
                     )
                     try:
                         records = await result.fetch(100)
@@ -681,7 +693,7 @@ class Neo4JStorage(BaseGraphStorage):
                             )
                     finally:
                         await result.consume()  # Ensure result is consumed

                 await session.execute_write(execute_upsert)
         except Exception as e:
             logger.error(f"Error during edge upsert: {str(e)}")
@@ -891,7 +903,9 @@ class Neo4JStorage(BaseGraphStorage):
             results = await session.run(query, {"node_id": node.id})
             # Get all records and release database connection
-            records = await results.fetch(1000)  # Max neighbour nodes we can handled
+            records = await results.fetch(
+                1000
+            )  # Max neighbour nodes we can handled
             await results.consume()  # Ensure results are consumed
             # Nodes not connected to start node need to check degree
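For context, the pattern these hunks reflow is the neo4j async driver's fetch-then-consume idiom: fetch a bounded number of records, then consume the result to release the connection. Below is a minimal standalone sketch of that idiom; the URI, credentials, database name, and the check_duplicate helper are illustrative assumptions, not code from this repository.

import asyncio

from neo4j import AsyncGraphDatabase


async def check_duplicate(label: str) -> bool:
    # Placeholder connection details; replace with real ones.
    driver = AsyncGraphDatabase.driver(
        "neo4j://localhost:7687", auth=("neo4j", "password")
    )
    try:
        async with driver.session(database="neo4j") as session:
            query = f"MATCH (n:`{label}`) RETURN n"
            result = await session.run(query)
            try:
                # Fetch at most 2 records: enough to tell 0, 1, or >1 apart.
                records = await result.fetch(2)
                return len(records) > 1
            finally:
                await result.consume()  # Ensure result is fully consumed
    finally:
        await driver.close()


if __name__ == "__main__":
    print(asyncio.run(check_duplicate("Person")))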

View File

@@ -11,7 +11,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
""" """
Log a message directly to stderr to ensure visibility in all processes, Log a message directly to stderr to ensure visibility in all processes,
including the Gunicorn master process. including the Gunicorn master process.
Args: Args:
message: The message to log message: The message to log
level: Log level (default: "INFO") level: Log level (default: "INFO")
@@ -44,7 +44,13 @@ _graph_db_lock: Optional[LockType] = None
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""

-    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool, name: str = "unnamed", enable_logging: bool = True):
+    def __init__(
+        self,
+        lock: Union[ProcessLock, asyncio.Lock],
+        is_async: bool,
+        name: str = "unnamed",
+        enable_logging: bool = True,
+    ):
         self._lock = lock
         self._is_async = is_async
         self._pid = os.getpid()  # for debug only
@@ -53,27 +59,47 @@ class UnifiedLock(Generic[T]):
     async def __aenter__(self) -> "UnifiedLock[T]":
         try:
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 await self._lock.acquire()
             else:
                 self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             return self
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise

     async def __aexit__(self, exc_type, exc_val, exc_tb):
         try:
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 self._lock.release()
             else:
                 self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise

     def __enter__(self) -> "UnifiedLock[T]":
@@ -81,12 +107,22 @@ class UnifiedLock(Generic[T]):
         try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                enable_output=self._enable_logging,
+            )
             return self
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise

     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -94,32 +130,62 @@ class UnifiedLock(Generic[T]):
         try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                enable_output=self._enable_logging,
+            )
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise


 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_internal_lock,
+        is_async=not is_multiprocess,
+        name="internal_lock",
+        enable_logging=enable_logging,
+    )


 def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_storage_lock,
+        is_async=not is_multiprocess,
+        name="storage_lock",
+        enable_logging=enable_logging,
+    )


 def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_pipeline_status_lock,
+        is_async=not is_multiprocess,
+        name="pipeline_status_lock",
+        enable_logging=enable_logging,
+    )


 def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified graph database lock for ensuring atomic operations"""
-    return UnifiedLock(lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_graph_db_lock,
+        is_async=not is_multiprocess,
+        name="graph_db_lock",
+        enable_logging=enable_logging,
+    )


 def initialize_share_data(workers: int = 1):
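For reference, the wrapper reformatted above implements the standard async context-manager protocol around either an asyncio or multiprocessing lock. A stripped-down, asyncio-only sketch of the same pattern (logging, the multiprocessing branch, and the module-level lock globals are omitted):

import asyncio
import os


class UnifiedLock:
    # Async-only reduction of the wrapper shown in the diff above.
    def __init__(self, lock: asyncio.Lock, name: str = "unnamed") -> None:
        self._lock = lock
        self._name = name
        self._pid = os.getpid()  # for debug only

    async def __aenter__(self) -> "UnifiedLock":
        await self._lock.acquire()  # the async branch of the real wrapper
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        self._lock.release()


async def main() -> None:
    graph_db_lock = UnifiedLock(asyncio.Lock(), name="graph_db_lock")
    async with graph_db_lock:  # the real wrapper rejects plain `with` here
        print(f"process {os.getpid()}: holding '{graph_db_lock._name}'")


asyncio.run(main())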

View File

@@ -522,8 +522,9 @@ async def extract_entities(
         maybe_edges[tuple(sorted(k))].extend(v)

     from .kg.shared_storage import get_graph_db_lock

-    graph_db_lock = get_graph_db_lock(enable_logging = True)
+    graph_db_lock = get_graph_db_lock(enable_logging=True)
+
     # Ensure that nodes and edges are merged and upserted atomically
     async with graph_db_lock:
         all_entities_data = await asyncio.gather(
@@ -535,7 +536,9 @@ async def extract_entities(
         all_relationships_data = await asyncio.gather(
             *[
-                _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
                 for k, v in maybe_edges.items()
             ]
         )
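The hunks in this file pair that lock with asyncio.gather so every merge-and-upsert runs inside a single critical section. A self-contained sketch of that shape, with a plain asyncio.Lock standing in for get_graph_db_lock and a stub in place of the real _merge_edges_then_upsert:

import asyncio

graph_db_lock = asyncio.Lock()  # stand-in for get_graph_db_lock(...)


async def _merge_edges_then_upsert(src: str, tgt: str, edges: list) -> dict:
    # Stub: the real function merges duplicate edges and upserts the result.
    await asyncio.sleep(0)
    return {"src_id": src, "tgt_id": tgt, "merged": len(edges)}


async def main() -> None:
    maybe_edges = {("a", "b"): [{"weight": 1.0}], ("b", "c"): [{"weight": 2.0}]}
    # Hold the lock across the whole batch so the merges land atomically.
    async with graph_db_lock:
        all_relationships_data = await asyncio.gather(
            *[
                _merge_edges_then_upsert(k[0], k[1], v)
                for k, v in maybe_edges.items()
            ]
        )
    print(all_relationships_data)


asyncio.run(main())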