Fix linting
@@ -181,10 +181,10 @@ class Neo4JStorage(BaseGraphStorage):
 
         Args:
             label: The label to validate
 
         Returns:
             str: The cleaned label
 
         Raises:
             ValueError: If label is empty after cleaning
         """
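Note: the docstring above belongs to a label-validation helper whose body is outside this hunk. A minimal sketch of the contract it describes; the function body below is hypothetical, only the behaviour (clean the label, raise `ValueError` when nothing is left) comes from the diff:

    # Hypothetical sketch of the validator contract described above.
    def _clean_label(label: str) -> str:
        cleaned = label.strip()  # assumption: cleaning means stripping whitespace
        if not cleaned:
            raise ValueError("Neo4j: label is empty after cleaning")
        return cleaned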
@@ -283,7 +283,9 @@ class Neo4JStorage(BaseGraphStorage):
             query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
             result = await session.run(query, entity_id=entity_name_label)
             try:
-                records = await result.fetch(2)  # Get 2 records for duplication check
+                records = await result.fetch(
+                    2
+                )  # Get 2 records for duplication check
 
                 if len(records) > 1:
                     logger.warning(
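Note: fetching exactly two records is enough to distinguish "no match", "unique match", and "duplicate" without buffering the whole result set. A minimal sketch of that pattern, assuming a session from the official `neo4j` async driver; the `Person` label and the helper name are placeholders, not code from this commit:

    # Sketch: duplicate check with a bounded fetch (neo4j async driver).
    from neo4j import AsyncSession

    async def entity_is_unique(session: AsyncSession, entity_id: str) -> bool:
        result = await session.run(
            "MATCH (n:`Person` {entity_id: $entity_id}) RETURN n",  # placeholder label
            entity_id=entity_id,
        )
        try:
            records = await result.fetch(2)  # two rows suffice: 0, 1, or >1 matches
            if len(records) > 1:
                print("duplicate nodes for this entity_id")
            return len(records) == 1
        finally:
            await result.consume()  # release the connection back to the pool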
@@ -552,6 +554,7 @@ class Neo4JStorage(BaseGraphStorage):
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
+
                 async def execute_upsert(tx: AsyncManagedTransaction):
                     query = f"""
                     MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
@@ -562,7 +565,7 @@ class Neo4JStorage(BaseGraphStorage):
                         f"Upserted node with label '{label}' and properties: {properties}"
                     )
                     await result.consume()  # Ensure result is fully consumed
 
                 await session.execute_write(execute_upsert)
         except Exception as e:
             logger.error(f"Error during upsert: {str(e)}")
@@ -602,18 +605,26 @@ class Neo4JStorage(BaseGraphStorage):
             """
             result = await session.run(query)
             try:
-                records = await result.fetch(2)  # We only need to know if there are 0, 1, or >1 nodes
+                records = await result.fetch(
+                    2
+                )  # We only need to know if there are 0, 1, or >1 nodes
 
                 if not records or records[0]["node_count"] == 0:
-                    raise ValueError(f"Neo4j: node with label '{node_label}' does not exist")
+                    raise ValueError(
+                        f"Neo4j: node with label '{node_label}' does not exist"
+                    )
 
                 if records[0]["node_count"] > 1:
-                    raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node")
+                    raise ValueError(
+                        f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
+                    )
 
                 node = records[0]["n"]
                 if "entity_id" not in node:
-                    raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property")
+                    raise ValueError(
+                        f"Neo4j: node with label '{node_label}' does not have an entity_id property"
+                    )
 
                 return node["entity_id"]
             finally:
                 await result.consume()  # Ensure result is fully consumed
@@ -656,6 +667,7 @@ class Neo4JStorage(BaseGraphStorage):
 
         try:
             async with self._driver.session(database=self._DATABASE) as session:
+
                 async def execute_upsert(tx: AsyncManagedTransaction):
                     query = f"""
                     MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
@@ -666,10 +678,10 @@ class Neo4JStorage(BaseGraphStorage):
                     RETURN r, source, target
                     """
                     result = await tx.run(
                         query,
                         source_entity_id=source_entity_id,
                         target_entity_id=target_entity_id,
-                        properties=edge_properties
+                        properties=edge_properties,
                     )
                     try:
                         records = await result.fetch(100)
@@ -681,7 +693,7 @@ class Neo4JStorage(BaseGraphStorage):
                         )
                     finally:
                         await result.consume()  # Ensure result is consumed
 
                 await session.execute_write(execute_upsert)
         except Exception as e:
             logger.error(f"Error during edge upsert: {str(e)}")
@@ -891,7 +903,9 @@ class Neo4JStorage(BaseGraphStorage):
                 results = await session.run(query, {"node_id": node.id})
 
                 # Get all records and release database connection
-                records = await results.fetch(1000)  # Max neighbour nodes we can handled
+                records = await results.fetch(
+                    1000
+                )  # Max neighbour nodes we can handled
                 await results.consume()  # Ensure results are consumed
 
                 # Nodes not connected to start node need to check degree

@@ -11,7 +11,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.
 
     Args:
         message: The message to log
         level: Log level (default: "INFO")
@@ -44,7 +44,13 @@ _graph_db_lock: Optional[LockType] = None
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
 
-    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool, name: str = "unnamed", enable_logging: bool = True):
+    def __init__(
+        self,
+        lock: Union[ProcessLock, asyncio.Lock],
+        is_async: bool,
+        name: str = "unnamed",
+        enable_logging: bool = True,
+    ):
         self._lock = lock
         self._is_async = is_async
         self._pid = os.getpid()  # for debug only
@@ -53,27 +59,47 @@ class UnifiedLock(Generic[T]):
 
     async def __aenter__(self) -> "UnifiedLock[T]":
         try:
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 await self._lock.acquire()
             else:
                 self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             return self
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         try:
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
             if self._is_async:
                 self._lock.release()
             else:
                 self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
     def __enter__(self) -> "UnifiedLock[T]":
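Note: `UnifiedLock` supports both `async with` (for an `asyncio.Lock`) and plain `with` (for a multiprocessing lock), and the sync path deliberately raises when pointed at an async lock. A minimal usage sketch under the assumption of a single-process, asyncio-backed lock; `demo_lock` is a hypothetical name, `UnifiedLock` is the class defined in the hunks above:

    # Sketch: UnifiedLock wrapping an asyncio.Lock (single-process case).
    import asyncio

    async def main():
        lock = UnifiedLock(
            lock=asyncio.Lock(),
            is_async=True,
            name="demo_lock",  # hypothetical name
            enable_logging=False,
        )
        async with lock:  # __aenter__ acquires, __aexit__ releases
            pass  # critical section

    asyncio.run(main())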
@@ -81,12 +107,22 @@ class UnifiedLock(Generic[T]):
         try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                enable_output=self._enable_logging,
+            )
             return self
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -94,32 +130,62 @@ class UnifiedLock(Generic[T]):
         try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
             self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                enable_output=self._enable_logging,
+            )
         except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
             raise
 
 
 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_internal_lock,
+        is_async=not is_multiprocess,
+        name="internal_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_storage_lock,
+        is_async=not is_multiprocess,
+        name="storage_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_pipeline_status_lock,
+        is_async=not is_multiprocess,
+        name="pipeline_status_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified graph database lock for ensuring atomic operations"""
-    return UnifiedLock(lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_graph_db_lock,
+        is_async=not is_multiprocess,
+        name="graph_db_lock",
+        enable_logging=enable_logging,
+    )
 
 
 def initialize_share_data(workers: int = 1):
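Note: the four getters differ only in which module-level lock they wrap. Because each passes `is_async=not is_multiprocess`, the same call site transparently gets an `asyncio.Lock` in single-worker mode and a multiprocessing lock under Gunicorn. Usage is uniform across all of them, matching the call shown in the next hunk:

    graph_db_lock = get_graph_db_lock(enable_logging=True)
    async with graph_db_lock:
        ...  # mutate the graph atomically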

@@ -522,8 +522,9 @@ async def extract_entities(
             maybe_edges[tuple(sorted(k))].extend(v)
 
     from .kg.shared_storage import get_graph_db_lock
-    graph_db_lock = get_graph_db_lock(enable_logging = True)
+
+    graph_db_lock = get_graph_db_lock(enable_logging=True)
 
     # Ensure that nodes and edges are merged and upserted atomically
     async with graph_db_lock:
         all_entities_data = await asyncio.gather(
@@ -535,7 +536,9 @@ async def extract_entities(
 
     all_relationships_data = await asyncio.gather(
         *[
-            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+            _merge_edges_then_upsert(
+                k[0], k[1], v, knowledge_graph_inst, global_config
+            )
             for k, v in maybe_edges.items()
         ]
     )
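Note: holding `graph_db_lock` across both `asyncio.gather` calls serializes whole merge-and-upsert batches against other pipelines, while the upserts inside one batch still run concurrently. A self-contained sketch of that shape, with a stand-in coroutine in place of `_merge_edges_then_upsert` and a plain `asyncio.Lock` in place of the shared lock:

    # Sketch: one lock around a concurrent batch of upserts.
    import asyncio

    async def merge_then_upsert(src: str, tgt: str) -> str:
        await asyncio.sleep(0)  # stand-in for real graph I/O
        return f"{src}->{tgt}"

    async def main():
        graph_db_lock = asyncio.Lock()  # stand-in for get_graph_db_lock()
        edges = [("a", "b"), ("b", "c")]
        async with graph_db_lock:  # batch is atomic w.r.t. other holders
            results = await asyncio.gather(
                *[merge_then_upsert(s, t) for s, t in edges]
            )
        print(results)

    asyncio.run(main())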