Optimize parallel processing of chunk extraction
@@ -25,7 +25,6 @@ from .utils import (
     CacheData,
     statistic_data,
     get_conversation_turns,
-    verbose_debug,
 )
 from .base import (
     BaseGraphStorage,
@@ -441,6 +440,12 @@ async def extract_entities(

     processed_chunks = 0
     total_chunks = len(ordered_chunks)
+    total_entities_count = 0
+    total_relations_count = 0
+
+    # Get lock manager from shared storage
+    from .kg.shared_storage import get_graph_db_lock
+    graph_db_lock = get_graph_db_lock(enable_logging=False)

     async def _user_llm_func_with_cache(
         input_text: str, history_messages: list[dict[str, str]] = None
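The lock is now created once at function scope and shared by every chunk task. As a rough single-process sketch of the idea (the real `get_graph_db_lock` lives in `.kg.shared_storage` and also covers multi-process deployments, so treat this as illustrative only):

```python
import asyncio

# Hypothetical stand-in for get_graph_db_lock: a registry that always
# hands back the same asyncio.Lock, so every coroutine that mutates the
# graph database serializes on one object.
_locks: dict[str, asyncio.Lock] = {}

def get_graph_db_lock(enable_logging: bool = False) -> asyncio.Lock:
    return _locks.setdefault("graph_db_lock", asyncio.Lock())
```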
@@ -539,7 +544,7 @@ async def extract_entities(
         chunk_key_dp (tuple[str, TextChunkSchema]):
             ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
         """
-        nonlocal processed_chunks
+        nonlocal processed_chunks, total_entities_count, total_relations_count
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
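`nonlocal` lets each per-chunk coroutine update counters owned by `extract_entities`. Because all tasks run on a single event loop and the increments happen while holding `graph_db_lock`, the read-modify-write cannot interleave. A toy demonstration of the pattern (not LightRAG code):

```python
import asyncio

async def main() -> None:
    total = 0
    lock = asyncio.Lock()

    async def worker(n: int) -> None:
        nonlocal total
        async with lock:  # mirrors graph_db_lock in the diff
            current = total
            await asyncio.sleep(0)  # a suspension point inside the lock is still safe
            total = current + n

    await asyncio.gather(*(worker(i) for i in range(10)))
    print(total)  # 45

asyncio.run(main())
```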
@@ -597,75 +602,30 @@ async def extract_entities(
             async with pipeline_status_lock:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
-        return dict(maybe_nodes), dict(maybe_edges)

-    tasks = [_process_single_content(c) for c in ordered_chunks]
-    results = await asyncio.gather(*tasks)
+        # Use graph database lock to ensure atomic merges and updates
+        chunk_entities_data = []
+        chunk_relationships_data = []

-    maybe_nodes = defaultdict(list)
-    maybe_edges = defaultdict(list)
-    for m_nodes, m_edges in results:
-        for k, v in m_nodes.items():
-            maybe_nodes[k].extend(v)
-        for k, v in m_edges.items():
-            maybe_edges[tuple(sorted(k))].extend(v)
-
-    from .kg.shared_storage import get_graph_db_lock
-
-    graph_db_lock = get_graph_db_lock(enable_logging=False)
-
-    # Ensure that nodes and edges are merged and upserted atomically
         async with graph_db_lock:
-        # serial processing nodes under lock
-        all_entities_data = []
-        for k, v in maybe_nodes.items():
-            entity_data = await _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-            all_entities_data.append(entity_data)
+            # Process and update entities
+            for entity_name, entities in maybe_nodes.items():
+                entity_data = await _merge_nodes_then_upsert(
+                    entity_name, entities, knowledge_graph_inst, global_config
+                )
+                chunk_entities_data.append(entity_data)

-        # serial processing edges under lock
-        all_relationships_data = []
-        for k, v in maybe_edges.items():
+            # Process and update relationships
+            for edge_key, edges in maybe_edges.items():
+                # Ensure edge direction consistency
+                sorted_edge_key = tuple(sorted(edge_key))
                 edge_data = await _merge_edges_then_upsert(
-                k[0], k[1], v, knowledge_graph_inst, global_config
+                    sorted_edge_key[0], sorted_edge_key[1], edges, knowledge_graph_inst, global_config
                 )
-            all_relationships_data.append(edge_data)
+                chunk_relationships_data.append(edge_data)

-        if not (all_entities_data or all_relationships_data):
-            log_message = "Didn't extract any entities and relationships."
-            logger.info(log_message)
-            if pipeline_status is not None:
-                async with pipeline_status_lock:
-                    pipeline_status["latest_message"] = log_message
-                    pipeline_status["history_messages"].append(log_message)
-            return
-
-        if not all_entities_data:
-            log_message = "Didn't extract any entities"
-            logger.info(log_message)
-            if pipeline_status is not None:
-                async with pipeline_status_lock:
-                    pipeline_status["latest_message"] = log_message
-                    pipeline_status["history_messages"].append(log_message)
-        if not all_relationships_data:
-            log_message = "Didn't extract any relationships"
-            logger.info(log_message)
-            if pipeline_status is not None:
-                async with pipeline_status_lock:
-                    pipeline_status["latest_message"] = log_message
-                    pipeline_status["history_messages"].append(log_message)
-
-        log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-        verbose_debug(
-            f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
-        )
-        verbose_debug(f"New relationships:{all_relationships_data}")
-
-        if entity_vdb is not None:
+            # Update vector database (within the same lock to ensure atomicity)
+            if entity_vdb is not None and chunk_entities_data:
                 data_for_vdb = {
                     compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                         "entity_name": dp["entity_name"],
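`sorted_edge_key` preserves the undirected-edge invariant the old code enforced with `maybe_edges[tuple(sorted(k))]`: whichever direction the LLM reports a relationship, both orientations merge into one record. A toy illustration:

```python
# Sorting the pair makes (src, tgt) and (tgt, src) collapse
# into a single undirected edge record.
edges = {("B", "A"): ["seen in chunk 1"], ("A", "B"): ["seen in chunk 2"]}

merged: dict[tuple[str, ...], list[str]] = {}
for key, descriptions in edges.items():
    merged.setdefault(tuple(sorted(key)), []).extend(descriptions)

print(merged)  # {('A', 'B'): ['seen in chunk 1', 'seen in chunk 2']}
```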
@@ -674,11 +634,11 @@ async def extract_entities(
                         "source_id": dp["source_id"],
                         "file_path": dp.get("file_path", "unknown_source"),
                     }
-                for dp in all_entities_data
+                    for dp in chunk_entities_data
                 }
                 await entity_vdb.upsert(data_for_vdb)

-        if relationships_vdb is not None:
+            if relationships_vdb is not None and chunk_relationships_data:
                 data_for_vdb = {
                     compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                         "src_id": dp["src_id"],
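The vector-database keys are deterministic content hashes, so re-running extraction upserts over the same IDs instead of duplicating rows. A minimal sketch of what a helper like `compute_mdhash_id` plausibly does (the actual implementation lives in `.utils`; this is an assumption, not its verified source):

```python
from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Deterministic ID: the same entity name (or src+tgt pair) always
    # maps to the same key, making repeated upserts idempotent.
    return prefix + md5(content.encode()).hexdigest()

print(compute_mdhash_id("Alan Turing", prefix="ent-"))
print(compute_mdhash_id("Alan Turing" + "Enigma", prefix="rel-"))
```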
@@ -688,10 +648,25 @@ async def extract_entities(
                         "source_id": dp["source_id"],
                         "file_path": dp.get("file_path", "unknown_source"),
                     }
-                for dp in all_relationships_data
+                    for dp in chunk_relationships_data
                 }
                 await relationships_vdb.upsert(data_for_vdb)

+            # Update counters
+            total_entities_count += len(chunk_entities_data)
+            total_relations_count += len(chunk_relationships_data)
+
+    # Handle all chunks in parallel
+    tasks = [_process_single_content(c) for c in ordered_chunks]
+    await asyncio.gather(*tasks)
+
+    log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
+    logger.info(log_message)
+    if pipeline_status is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
+

 async def kg_query(
     query: str,
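Taken together, the commit inverts the old "gather results, then merge everything once" structure: each chunk now merges and upserts its own findings under the shared lock, and the final `asyncio.gather` only needs side effects, not return values. A condensed, self-contained sketch of the new control flow, with hypothetical stand-ins for the LLM call and the `_merge_*_then_upsert` helpers:

```python
import asyncio

# Hypothetical stand-ins for the LLM extraction step and the
# _merge_nodes_then_upsert/_merge_edges_then_upsert helpers.
async def run_llm_extraction(chunk: str) -> tuple[dict, dict]:
    await asyncio.sleep(0)  # pretend to call the LLM (runs concurrently)
    return {f"entity-{chunk}": ["desc"]}, {(f"entity-{chunk}", "hub"): ["rel"]}

async def merge_and_upsert(key, values) -> None:
    await asyncio.sleep(0)  # pretend to write to graph + vector storage

async def extract_entities_sketch(chunks: list[str]) -> tuple[int, int]:
    total_entities = 0
    total_relations = 0
    graph_db_lock = asyncio.Lock()

    async def process_chunk(chunk: str) -> None:
        nonlocal total_entities, total_relations
        # Extraction is the slow, parallel part: no lock held here.
        nodes, edges = await run_llm_extraction(chunk)
        # Merging and upserting is the serialized, atomic part.
        async with graph_db_lock:
            for k, v in nodes.items():
                await merge_and_upsert(k, v)
            for k, v in edges.items():
                await merge_and_upsert(tuple(sorted(k)), v)
            total_entities += len(nodes)
            total_relations += len(edges)

    # Fan out all chunks; results flow through shared state, so no
    # return values are collected, mirroring `await asyncio.gather(*tasks)`.
    await asyncio.gather(*(process_chunk(c) for c in chunks))
    return total_entities, total_relations

print(asyncio.run(extract_entities_sketch(["c1", "c2", "c3"])))  # (3, 3)
```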