Merge branch 'main' into main

Alex Z
2025-04-05 15:27:59 -07:00
committed by GitHub
77 changed files with 5920 additions and 5192 deletions

@@ -26,7 +26,6 @@ from .utils import (
     CacheData,
     statistic_data,
     get_conversation_turns,
-    verbose_debug,
 )
 from .base import (
     BaseGraphStorage,
@@ -442,6 +441,13 @@ async def extract_entities(
     processed_chunks = 0
     total_chunks = len(ordered_chunks)
+    total_entities_count = 0
+    total_relations_count = 0
+
+    # Get lock manager from shared storage
+    from .kg.shared_storage import get_graph_db_lock
+
+    graph_db_lock = get_graph_db_lock(enable_logging=False)
 
     async def _user_llm_func_with_cache(
         input_text: str, history_messages: list[dict[str, str]] = None
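
Note on the hunk above: pulling get_graph_db_lock up to the top of extract_entities means the lock object is created once and then captured by every per-chunk coroutine, which is what makes the later per-chunk merging safe. A minimal sketch of that pattern, using a plain asyncio.Lock as a stand-in for LightRAG's shared-storage lock (all names below are illustrative):

```python
import asyncio

# One lock created up front and shared by all workers, mirroring how
# graph_db_lock is created once before _process_single_content is defined.
graph_db_lock = asyncio.Lock()
merged: dict[str, int] = {}


async def process_chunk(name: str) -> None:
    await asyncio.sleep(0)  # extraction work can overlap freely
    async with graph_db_lock:
        # Only one chunk at a time mutates the shared structure.
        merged[name] = merged.get(name, 0) + 1


async def main() -> None:
    await asyncio.gather(*(process_chunk(f"chunk-{i}") for i in range(4)))
    print(merged)


asyncio.run(main())
```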
@@ -540,7 +546,7 @@ async def extract_entities(
             chunk_key_dp (tuple[str, TextChunkSchema]):
                 ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
         """
-        nonlocal processed_chunks
+        nonlocal processed_chunks, total_entities_count, total_relations_count
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
@@ -598,102 +604,74 @@ async def extract_entities(
             async with pipeline_status_lock:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
-        return dict(maybe_nodes), dict(maybe_edges)
-
-    tasks = [_process_single_content(c) for c in ordered_chunks]
-    results = await asyncio.gather(*tasks)
-
-    maybe_nodes = defaultdict(list)
-    maybe_edges = defaultdict(list)
-    for m_nodes, m_edges in results:
-        for k, v in m_nodes.items():
-            maybe_nodes[k].extend(v)
-        for k, v in m_edges.items():
-            maybe_edges[tuple(sorted(k))].extend(v)
-
-    from .kg.shared_storage import get_graph_db_lock
-
-    graph_db_lock = get_graph_db_lock(enable_logging=False)
-
-    # Ensure that nodes and edges are merged and upserted atomically
-    async with graph_db_lock:
-        all_entities_data = await asyncio.gather(
-            *[
-                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
-                for k, v in maybe_nodes.items()
-            ]
-        )
-
-        all_relationships_data = await asyncio.gather(
-            *[
-                _merge_edges_then_upsert(
-                    k[0], k[1], v, knowledge_graph_inst, global_config
-                )
-                for k, v in maybe_edges.items()
-            ]
-        )
-
-    if not (all_entities_data or all_relationships_data):
-        log_message = "Didn't extract any entities and relationships."
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-        return
-
-    if not all_entities_data:
-        log_message = "Didn't extract any entities"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-    if not all_relationships_data:
-        log_message = "Didn't extract any relationships"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-    log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
+
+        # Use graph database lock to ensure atomic merges and updates
+        chunk_entities_data = []
+        chunk_relationships_data = []
+
+        async with graph_db_lock:
+            # Process and update entities
+            for entity_name, entities in maybe_nodes.items():
+                entity_data = await _merge_nodes_then_upsert(
+                    entity_name, entities, knowledge_graph_inst, global_config
+                )
+                chunk_entities_data.append(entity_data)
+
+            # Process and update relationships
+            for edge_key, edges in maybe_edges.items():
+                # Ensure edge direction consistency
+                sorted_edge_key = tuple(sorted(edge_key))
+                edge_data = await _merge_edges_then_upsert(
+                    sorted_edge_key[0],
+                    sorted_edge_key[1],
+                    edges,
+                    knowledge_graph_inst,
+                    global_config,
+                )
+                chunk_relationships_data.append(edge_data)
+
+            # Update vector database (within the same lock to ensure atomicity)
+            if entity_vdb is not None and chunk_entities_data:
+                data_for_vdb = {
+                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                        "entity_name": dp["entity_name"],
+                        "entity_type": dp["entity_type"],
+                        "content": f"{dp['entity_name']}\n{dp['description']}",
+                        "source_id": dp["source_id"],
+                        "file_path": dp.get("file_path", "unknown_source"),
+                    }
+                    for dp in chunk_entities_data
+                }
+                await entity_vdb.upsert(data_for_vdb)
+
+            if relationships_vdb is not None and chunk_relationships_data:
+                data_for_vdb = {
+                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                        "src_id": dp["src_id"],
+                        "tgt_id": dp["tgt_id"],
+                        "keywords": dp["keywords"],
+                        "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
+                        "source_id": dp["source_id"],
+                        "file_path": dp.get("file_path", "unknown_source"),
+                    }
+                    for dp in chunk_relationships_data
+                }
+                await relationships_vdb.upsert(data_for_vdb)
+
+            # Update counters
+            total_entities_count += len(chunk_entities_data)
+            total_relations_count += len(chunk_relationships_data)
+
+    # Handle all chunks in parallel
+    tasks = [_process_single_content(c) for c in ordered_chunks]
+    await asyncio.gather(*tasks)
+
+    log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
     logger.info(log_message)
     if pipeline_status is not None:
         async with pipeline_status_lock:
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)
-    verbose_debug(
-        f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
-    )
-    verbose_debug(f"New relationships:{all_relationships_data}")
-
-    if entity_vdb is not None:
-        data_for_vdb = {
-            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                "entity_name": dp["entity_name"],
-                "entity_type": dp["entity_type"],
-                "content": f"{dp['entity_name']}\n{dp['description']}",
-                "source_id": dp["source_id"],
-                "file_path": dp.get("file_path", "unknown_source"),
-            }
-            for dp in all_entities_data
-        }
-        await entity_vdb.upsert(data_for_vdb)
-
-    if relationships_vdb is not None:
-        data_for_vdb = {
-            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                "src_id": dp["src_id"],
-                "tgt_id": dp["tgt_id"],
-                "keywords": dp["keywords"],
-                "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
-                "source_id": dp["source_id"],
-                "file_path": dp.get("file_path", "unknown_source"),
-            }
-            for dp in all_relationships_data
-        }
-        await relationships_vdb.upsert(data_for_vdb)
 
 
 async def kg_query(
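
To summarize the refactor in the hunk above: instead of gathering every chunk's raw nodes and edges and merging them in one pass at the end, each chunk now merges and upserts its own results inside the shared graph lock and bumps nonlocal counters, so the final log reports running totals. A stripped-down sketch of that control flow (the merge step and counters below are placeholders, not the real LightRAG APIs):

```python
import asyncio

graph_db_lock = asyncio.Lock()


async def extract_entities_sketch(chunks: list[str]) -> tuple[int, int]:
    total_entities_count = 0
    total_relations_count = 0

    async def _process_single_content(chunk: str) -> None:
        nonlocal total_entities_count, total_relations_count
        # Placeholder extraction step; the real code prompts an LLM here.
        maybe_nodes = {f"{chunk}:entity": ["desc"]}
        maybe_edges = {(chunk, "doc"): ["rel"]}
        # Merge + upsert for this chunk is atomic with respect to other chunks.
        async with graph_db_lock:
            total_entities_count += len(maybe_nodes)
            total_relations_count += len(maybe_edges)

    # All chunks run concurrently; only the merge section is serialized.
    await asyncio.gather(*(_process_single_content(c) for c in chunks))
    return total_entities_count, total_relations_count


print(asyncio.run(extract_entities_sketch(["c1", "c2", "c3"])))
```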
@@ -720,8 +698,7 @@ async def kg_query(
     if cached_response is not None:
         return cached_response
 
-    # Extract keywords using extract_keywords_only function which already supports conversation history
-    hl_keywords, ll_keywords = await extract_keywords_only(
+    hl_keywords, ll_keywords = await get_keywords_from_query(
         query, query_param, global_config, hashing_kv
     )
@@ -817,6 +794,38 @@ async def kg_query(
     return response
 
 
+async def get_keywords_from_query(
+    query: str,
+    query_param: QueryParam,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> tuple[list[str], list[str]]:
+    """
+    Retrieves high-level and low-level keywords for RAG operations.
+
+    This function checks if keywords are already provided in query parameters,
+    and if not, extracts them from the query text using LLM.
+
+    Args:
+        query: The user's query text
+        query_param: Query parameters that may contain pre-defined keywords
+        global_config: Global configuration dictionary
+        hashing_kv: Optional key-value storage for caching results
+
+    Returns:
+        A tuple containing (high_level_keywords, low_level_keywords)
+    """
+    # Check if pre-defined keywords are already provided
+    if query_param.hl_keywords or query_param.ll_keywords:
+        return query_param.hl_keywords, query_param.ll_keywords
+
+    # Extract keywords using extract_keywords_only function which already supports conversation history
+    hl_keywords, ll_keywords = await extract_keywords_only(
+        query, query_param, global_config, hashing_kv
+    )
+    return hl_keywords, ll_keywords
+
+
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
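
The new get_keywords_from_query helper becomes the single entry point used by kg_query, mix_kg_vector_query, and query_with_keywords: keywords already set on the QueryParam short-circuit the LLM call, otherwise extract_keywords_only runs as before. A self-contained sketch of that short-circuit, with a trimmed-down stand-in for QueryParam and a dummy extractor in place of the LLM call:

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class QueryParamSketch:
    # Only the two fields the helper inspects; everything else omitted.
    hl_keywords: list[str] = field(default_factory=list)
    ll_keywords: list[str] = field(default_factory=list)


async def get_keywords_sketch(
    query: str, param: QueryParamSketch
) -> tuple[list[str], list[str]]:
    # Pre-defined keywords win: no LLM round-trip is made.
    if param.hl_keywords or param.ll_keywords:
        return param.hl_keywords, param.ll_keywords
    # Dummy extraction standing in for extract_keywords_only().
    words = query.split()
    return [w for w in words if w.istitle()], [w.lower() for w in words]


print(asyncio.run(get_keywords_sketch("Tesla battery supply chain", QueryParamSketch())))
print(asyncio.run(get_keywords_sketch("ignored", QueryParamSketch(hl_keywords=["energy"]))))
```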
@@ -957,8 +966,7 @@ async def mix_kg_vector_query(
     # 2. Execute knowledge graph and vector searches in parallel
     async def get_kg_context():
         try:
-            # Extract keywords using extract_keywords_only function which already supports conversation history
-            hl_keywords, ll_keywords = await extract_keywords_only(
+            hl_keywords, ll_keywords = await get_keywords_from_query(
                 query, query_param, global_config, hashing_kv
             )
@@ -1339,7 +1347,9 @@ async def _get_node_data(
     text_units_section_list = [["id", "content", "file_path"]]
 
     for i, t in enumerate(use_text_units):
-        text_units_section_list.append([i, t["content"], t["file_path"]])
+        text_units_section_list.append(
+            [i, t["content"], t.get("file_path", "unknown_source")]
+        )
 
     text_units_context = list_of_list_to_csv(text_units_section_list)
     return entities_context, relations_context, text_units_context
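
The _get_node_data change above is defensive: text units written before file-path tracking existed have no file_path key, so the lookup now degrades to "unknown_source" instead of raising KeyError. A small illustration (the comma join below merely stands in for LightRAG's list_of_list_to_csv helper):

```python
use_text_units = [
    {"content": "chunk with a recorded path", "file_path": "docs/a.md"},
    {"content": "legacy chunk without a path"},  # no file_path key
]

text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
    # t["file_path"] would raise KeyError on the legacy chunk;
    # .get() falls back to a sentinel value instead.
    text_units_section_list.append(
        [str(i), t["content"], t.get("file_path", "unknown_source")]
    )

print("\n".join(",".join(row) for row in text_units_section_list))
```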
@@ -2043,16 +2053,13 @@ async def query_with_keywords(
         Query response or async iterator
     """
     # Extract keywords
-    hl_keywords, ll_keywords = await extract_keywords_only(
-        text=query,
-        param=param,
+    hl_keywords, ll_keywords = await get_keywords_from_query(
+        query=query,
+        query_param=param,
         global_config=global_config,
         hashing_kv=hashing_kv,
     )
-    param.hl_keywords = hl_keywords
-    param.ll_keywords = ll_keywords
-
     # Create a new string with the prompt and the keywords
     ll_keywords_str = ", ".join(ll_keywords)
     hl_keywords_str = ", ".join(hl_keywords)
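
With the switch to get_keywords_from_query, query_with_keywords no longer writes the extracted keywords back onto param; it only folds them into the prompt text. A sketch of that final formatting step, assuming hl_keywords and ll_keywords are the lists returned by the helper (the exact prompt template in the real code may differ):

```python
query = "How do tariffs affect grid storage deployment?"
hl_keywords = ["energy policy", "grid storage"]
ll_keywords = ["lithium", "tariffs"]

# Keywords are joined into comma-separated strings and embedded in the
# query text rather than being stored back on the QueryParam object.
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = (
    f"{query}\n\n"
    f"High-level keywords: {hl_keywords_str}\n"
    f"Low-level keywords: {ll_keywords_str}"
)
print(formatted_question)
```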