fix truncation with global_config tokenizer

Author: drahnreb
Date: 2025-04-17 13:09:52 +02:00
Parent: 0e6771b503
Commit: 0f949dd5d7
2 changed files with 36 additions and 8 deletions
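
The change threads global_config through every knowledge-graph query helper so that each truncation site reads the shared tokenizer from it, rather than fetching the tokenizer only at prompt-assembly time. The pattern repeated at each call site, as a minimal sketch assembled from the diff below (global_config, query_param, and valid_chunks are assumed to be in scope at the call site):

    # Fetch the tokenizer once from the shared configuration...
    tokenizer: Tokenizer = global_config["tokenizer"]
    # ...then pass it explicitly to each token-based truncation call.
    maybe_trun_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
        tokenizer=tokenizer,
    )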

File 1 of 2:

@@ -842,6 +842,7 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )

     if query_param.only_need_context:
@@ -1057,6 +1058,8 @@ async def mix_kg_vector_query(
     2. Retrieving relevant text chunks through vector similarity
     3. Combining both results for comprehensive answer generation
     """
+    # get tokenizer
+    tokenizer: Tokenizer = global_config["tokenizer"]
     # 1. Cache handling
     use_model_func = (
         query_param.model_func
@@ -1111,6 +1114,7 @@ async def mix_kg_vector_query(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )

         return context
@@ -1156,6 +1160,7 @@ async def mix_kg_vector_query(
             valid_chunks,
             key=lambda x: x["content"],
             max_token_size=query_param.max_token_for_text_unit,
+            tokenizer=tokenizer,
         )

         if not maybe_trun_chunks:
@@ -1213,7 +1218,6 @@ async def mix_kg_vector_query(
     if query_param.only_need_prompt:
         return sys_prompt

-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
@@ -1263,6 +1267,7 @@ async def _build_query_context(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(f"Process {os.getpid()} buidling query context...")
     if query_param.mode == "local":
@@ -1272,6 +1277,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
     elif query_param.mode == "global":
         entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1280,6 +1286,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
     else: # hybrid mode
         ll_data = await _get_node_data(
@@ -1288,6 +1295,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
         hl_data = await _get_edge_data(
             hl_keywords,
@@ -1295,6 +1303,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )

         (
@@ -1341,6 +1350,7 @@ async def _get_node_data(
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     # get similar entities
     logger.info(
@@ -1377,17 +1387,19 @@ async def _get_node_data(
     ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst
+        node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst
+        node_datas, query_param, knowledge_graph_inst, global_config
     )

+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1481,6 +1493,7 @@ async def _find_most_related_text_unit_from_entities(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1562,14 +1575,15 @@ async def _find_most_related_text_unit_from_entities(
         logger.warning("No valid text units found")
         return []

+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_text_units = sorted(
         all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
     )
     all_text_units = truncate_list_by_token_size(
         all_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )

     logger.debug(
@@ -1584,6 +1598,7 @@ async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     node_names = [dp["entity_name"] for dp in node_datas]
     batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1623,6 +1638,7 @@ async def _find_most_related_edges_from_entities(
         }
         all_edges_data.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_edges_data = sorted(
         all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1630,6 +1646,7 @@ async def _find_most_related_edges_from_entities(
         all_edges_data,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )

     logger.debug(
@@ -1645,6 +1662,7 @@ async def _get_edge_data(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(
         f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1685,6 +1703,7 @@ async def _get_edge_data(
         }
         edge_datas.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
    edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1692,13 +1711,14 @@ async def _get_edge_data(
         edge_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst
+            edge_datas, query_param, knowledge_graph_inst, global_config
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+            edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
         ),
     )

     logger.info(
@@ -1778,6 +1798,7 @@ async def _find_most_related_entities_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     entity_names = []
     seen = set()
@@ -1808,11 +1829,13 @@ async def _find_most_related_entities_from_relationships(
         combined = {**node, "entity_name": entity_name, "rank": degree}
         node_datas.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1826,6 +1849,7 @@ async def _find_related_text_unit_from_relationships(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1867,10 +1891,12 @@ async def _find_related_text_unit_from_relationships(
         logger.warning("No valid text chunks after filtering")
         return []

+    tokenizer: Tokenizer = global_config["tokenizer"]
     truncated_text_units = truncate_list_by_token_size(
         valid_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )

     logger.debug(
@@ -1941,10 +1967,12 @@ async def naive_query(
         logger.warning("No valid chunks found after filtering")
         return PROMPTS["fail_response"]

+    tokenizer: Tokenizer = global_config["tokenizer"]
     maybe_trun_chunks = truncate_list_by_token_size(
         valid_chunks,
         key=lambda x: x["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )

     if not maybe_trun_chunks:
@@ -1982,7 +2010,6 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt

-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
@@ -2101,6 +2128,7 @@ async def kg_query_with_keywords(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )

     if not context:
         return PROMPTS["fail_response"]

File 2 of 2:

@@ -424,7 +424,7 @@ def is_float_regex(value: str) -> bool:
 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
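
The hunk above only shows the signature change; the body goes on counting tokens, now with the caller-supplied tokenizer. A plausible reconstruction of the updated helper, assuming only that Tokenizer exposes encode(text) returning a list of token ids (the one method the operate.py changes above rely on):

    from typing import Any, Callable

    def truncate_list_by_token_size(
        list_data: list[Any],
        key: Callable[[Any], str],
        max_token_size: int,
        tokenizer: Tokenizer,
    ) -> list[int]:
        """Truncate a list of data by token size"""
        if max_token_size <= 0:
            return []
        tokens = 0
        for i, data in enumerate(list_data):
            # Count tokens with the injected tokenizer instead of a
            # module-level default.
            tokens += len(tokenizer.encode(key(data)))
            if tokens > max_token_size:
                # Keep only the prefix that fits within the budget.
                # (The list[int] annotation is retained from the source,
                # though the function returns a slice of list_data.)
                return list_data[:i]
        return list_data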