fix truncation with global_config tokenizer

author drahnreb
date 2025-04-17 13:09:52 +02:00
parent 0e6771b503
commit 0f949dd5d7
2 changed files with 36 additions and 8 deletions
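In short: global_config is threaded from the top-level query functions (kg_query, mix_kg_vector_query, naive_query, kg_query_with_keywords) down through _build_query_context and the _get_*/_find_* helpers, so every truncation site can fetch the instance-configured tokenizer from global_config["tokenizer"] and pass it to truncate_list_by_token_size explicitly. A minimal runnable sketch of that pattern, assuming only what the diff shows (the "tokenizer" key, an encode() method, the tokenizer= keyword); CharTokenizer and shrink are hypothetical stand-ins:

from typing import Any, Callable, Protocol

class Tokenizer(Protocol):
    def encode(self, text: str) -> list[int]: ...

class CharTokenizer:
    # Toy stand-in: one token per character. The real object is whatever
    # encoder the instance stored under global_config["tokenizer"].
    def encode(self, text: str) -> list[int]:
        return [ord(c) for c in text]

def shrink(items: list[Any], key: Callable[[Any], str],
           budget: int, tokenizer: Tokenizer) -> list[Any]:
    # Hypothetical stand-in for truncate_list_by_token_size (the real
    # helper appears in the utils.py hunk at the end of this diff).
    kept: list[Any] = []
    used = 0
    for item in items:
        used += len(tokenizer.encode(key(item)))
        if used > budget:
            break
        kept.append(item)
    return kept

def build_context(chunks: list[dict], global_config: dict) -> list[dict]:
    # The pattern this commit repeats at every call site: fetch the
    # tokenizer once, then pass it down explicitly.
    tokenizer: Tokenizer = global_config["tokenizer"]
    return shrink(chunks, key=lambda x: x["content"],
                  budget=9, tokenizer=tokenizer)

config = {"tokenizer": CharTokenizer()}
chunks = [{"content": "alpha"}, {"content": "beta"}, {"content": "gamma"}]
print([c["content"] for c in build_context(chunks, config)])  # ['alpha', 'beta']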

View File

@@ -842,6 +842,7 @@ async def kg_query(
relationships_vdb,
text_chunks_db,
query_param,
+ global_config,
)
if query_param.only_need_context:
@@ -1057,6 +1058,8 @@ async def mix_kg_vector_query(
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
"""
+ # get tokenizer
+ tokenizer: Tokenizer = global_config["tokenizer"]
# 1. Cache handling
use_model_func = (
query_param.model_func
@@ -1111,6 +1114,7 @@ async def mix_kg_vector_query(
relationships_vdb,
text_chunks_db,
query_param,
+ global_config,
)
return context
@@ -1156,6 +1160,7 @@ async def mix_kg_vector_query(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
+ tokenizer=tokenizer,
)
if not maybe_trun_chunks:
@@ -1213,7 +1218,6 @@ async def mix_kg_vector_query(
if query_param.only_need_prompt:
return sys_prompt
- tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
@@ -1263,6 +1267,7 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
+ global_config: dict[str, str],
):
logger.info(f"Process {os.getpid()} buidling query context...")
if query_param.mode == "local":
@@ -1272,6 +1277,7 @@ async def _build_query_context(
entities_vdb,
text_chunks_db,
query_param,
+ global_config,
)
elif query_param.mode == "global":
entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1280,6 +1286,7 @@ async def _build_query_context(
relationships_vdb,
text_chunks_db,
query_param,
+ global_config,
)
else: # hybrid mode
ll_data = await _get_node_data(
@@ -1288,6 +1295,7 @@ async def _build_query_context(
entities_vdb,
text_chunks_db,
query_param,
+ global_config,
)
hl_data = await _get_edge_data(
hl_keywords,
@@ -1295,6 +1303,7 @@ async def _build_query_context(
relationships_vdb,
text_chunks_db,
query_param,
+ global_config,
)
(
@@ -1341,6 +1350,7 @@ async def _get_node_data(
entities_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
+ global_config: dict[str, str],
):
# get similar entities
logger.info(
@@ -1377,17 +1387,19 @@ async def _get_node_data(
] # what is this text_chunks_db doing? don't remember it in airvx. check the diagram.
# get entity text chunks
use_text_units = await _find_most_related_text_unit_from_entities(
- node_datas, query_param, text_chunks_db, knowledge_graph_inst
+ node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
)
use_relations = await _find_most_related_edges_from_entities(
- node_datas, query_param, knowledge_graph_inst
+ node_datas, query_param, knowledge_graph_inst, global_config
)
+ tokenizer: Tokenizer = global_config["tokenizer"]
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
+ tokenizer=tokenizer,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1481,6 +1493,7 @@ async def _find_most_related_text_unit_from_entities(
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
+ global_config: dict[str, str],
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1562,14 +1575,15 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found")
return []
+ tokenizer: Tokenizer = global_config["tokenizer"]
all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
)
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
+ tokenizer=tokenizer,
)
logger.debug(
@@ -1584,6 +1598,7 @@ async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
+ global_config: dict[str, str],
):
node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1623,6 +1638,7 @@ async def _find_most_related_edges_from_entities(
}
all_edges_data.append(combined)
+ tokenizer: Tokenizer = global_config["tokenizer"]
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1630,6 +1646,7 @@ async def _find_most_related_edges_from_entities(
all_edges_data,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
+ tokenizer=tokenizer,
)
logger.debug(
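The edge truncation above runs on a list pre-sorted by (rank, weight) in descending order, so the token budget trims the least-connected, lowest-weight relations first. A small self-contained illustration with hypothetical edge records (only the sort key and reverse=True come from the diff):

# Hypothetical edge records; rank is the graph degree, weight the edge score.
edges = [
    {"description": "A -> B, strong", "rank": 3, "weight": 0.9},
    {"description": "B -> C", "rank": 1, "weight": 0.2},
    {"description": "A -> C, weak", "rank": 3, "weight": 0.4},
]
edges = sorted(edges, key=lambda x: (x["rank"], x["weight"]), reverse=True)
# Highest-rank, highest-weight edges come first, so a subsequent
# token-budget truncation drops the least useful ones from the tail.
print([e["description"] for e in edges])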
@@ -1645,6 +1662,7 @@ async def _get_edge_data(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
+ global_config: dict[str, str],
):
logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1685,6 +1703,7 @@ async def _get_edge_data(
}
edge_datas.append(combined)
+ tokenizer: Tokenizer = global_config["tokenizer"]
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1692,13 +1711,14 @@ async def _get_edge_data(
edge_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_global_context,
+ tokenizer=tokenizer,
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
- edge_datas, query_param, knowledge_graph_inst
+ edge_datas, query_param, knowledge_graph_inst, global_config
),
_find_related_text_unit_from_relationships(
- edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+ edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
),
)
logger.info(
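_get_edge_data fans out its two follow-up lookups concurrently with asyncio.gather, and both coroutines now receive global_config so their internal truncation can reach the tokenizer. A runnable sketch of that fan-out; the two stub coroutines below are hypothetical stand-ins for the real helpers:

import asyncio

async def find_entities(edge_datas: list[dict], global_config: dict) -> list[str]:
    # Stub for _find_most_related_entities_from_relationships.
    return [e["src"] for e in edge_datas]

async def find_text_units(edge_datas: list[dict], global_config: dict) -> list[str]:
    # Stub for _find_related_text_unit_from_relationships.
    return [f"chunk for {e['src']}->{e['tgt']}" for e in edge_datas]

async def main() -> None:
    config = {"tokenizer": None}  # stand-in; carries the real tokenizer
    edges = [{"src": "A", "tgt": "B"}, {"src": "B", "tgt": "C"}]
    # Both lookups run concurrently, as in the diff's asyncio.gather call.
    use_entities, use_text_units = await asyncio.gather(
        find_entities(edges, config),
        find_text_units(edges, config),
    )
    print(use_entities, use_text_units)

asyncio.run(main())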
@@ -1778,6 +1798,7 @@ async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
+ global_config: dict[str, str],
):
entity_names = []
seen = set()
@@ -1808,11 +1829,13 @@ async def _find_most_related_entities_from_relationships(
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
+ tokenizer: Tokenizer = global_config["tokenizer"]
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"] if x["description"] is not None else "",
max_token_size=query_param.max_token_for_local_context,
+ tokenizer=tokenizer,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1826,6 +1849,7 @@ async def _find_related_text_unit_from_relationships(
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
+ global_config: dict[str, str],
):
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1867,10 +1891,12 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering")
return []
+ tokenizer: Tokenizer = global_config["tokenizer"]
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
+ tokenizer=tokenizer,
)
logger.debug(
@@ -1941,10 +1967,12 @@ async def naive_query(
logger.warning("No valid chunks found after filtering")
return PROMPTS["fail_response"]
+ tokenizer: Tokenizer = global_config["tokenizer"]
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
+ tokenizer=tokenizer,
)
if not maybe_trun_chunks:
@@ -1982,7 +2010,6 @@ async def naive_query(
if query_param.only_need_prompt:
return sys_prompt
- tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
@@ -2101,6 +2128,7 @@ async def kg_query_with_keywords(
relationships_vdb,
text_chunks_db,
query_param,
+ global_config,
)
if not context:
return PROMPTS["fail_response"]

View File

@@ -424,7 +424,7 @@ def is_float_regex(value: str) -> bool:
def truncate_list_by_token_size(
- list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+ list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
) -> list[int]:
"""Truncate a list of data by token size"""
if max_token_size <= 0:
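The hunk ends just after the guard clause. For reference, a body consistent with every call site above would greedily keep items while their cumulative token count fits the budget; the sketch below is an assumed completion, not the committed code:

from typing import Any, Callable, Protocol

class Tokenizer(Protocol):
    # Shape assumed from the call sites above: encode() returning token ids.
    def encode(self, text: str) -> list[int]: ...

def truncate_list_by_token_size(
    list_data: list[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    tokenizer: Tokenizer,
) -> list[Any]:
    """Truncate a list of data by token size."""
    # Assumed completion: greedy prefix under a cumulative token budget.
    # Note the diff itself keeps a `-> list[int]` annotation even though
    # a prefix of list_data is what gets returned.
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, item in enumerate(list_data):
        tokens += len(tokenizer.encode(key(item)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data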