Merge branch 'main' into add-multi-worker-support

This commit is contained in:
yangdx
2025-03-01 15:55:37 +08:00
31 changed files with 1755 additions and 1371 deletions

View File

@@ -323,6 +323,7 @@ async def _merge_edges_then_upsert(
tgt_id=tgt_id,
description=description,
keywords=keywords,
source_id=source_id,
)
return edge_data
@@ -365,7 +366,7 @@ async def extract_entities(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(entity_types),
entity_types=", ".join(entity_types),
language=language,
)
# add example's format
@@ -562,6 +563,7 @@ async def extract_entities(
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
"source_id": dp["source_id"],
}
for dp in all_entities_data
}
@@ -572,6 +574,7 @@ async def extract_entities(
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"source_id": dp["source_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
@@ -595,7 +598,7 @@ async def kg_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
) -> str:
) -> str | AsyncIterator[str]:
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1127,7 +1130,7 @@ async def _get_node_data(
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"],
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
)
logger.debug(
@@ -1310,7 +1313,7 @@ async def _find_most_related_edges_from_entities(
)
all_edges_data = truncate_list_by_token_size(
all_edges_data,
key=lambda x: x["description"],
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
)
@@ -1364,7 +1367,7 @@ async def _get_edge_data(
)
edge_datas = truncate_list_by_token_size(
edge_datas,
key=lambda x: x["description"],
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
)
use_entities, use_text_units = await asyncio.gather(
@@ -1468,7 +1471,7 @@ async def _find_most_related_entities_from_relationships(
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"],
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
)
logger.debug(