fix truncation with global_config tokenizer
This commit is contained in:
@@ -842,6 +842,7 @@ async def kg_query(
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
global_config,
|
||||
)
|
||||
|
||||
if query_param.only_need_context:
|
||||
@@ -1057,6 +1058,8 @@ async def mix_kg_vector_query(
|
||||
2. Retrieving relevant text chunks through vector similarity
|
||||
3. Combining both results for comprehensive answer generation
|
||||
"""
|
||||
# get tokenizer
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
# 1. Cache handling
|
||||
use_model_func = (
|
||||
query_param.model_func
|
||||
@@ -1111,6 +1114,7 @@ async def mix_kg_vector_query(
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
global_config,
|
||||
)
|
||||
|
||||
return context
|
||||
@@ -1156,6 +1160,7 @@ async def mix_kg_vector_query(
|
||||
valid_chunks,
|
||||
key=lambda x: x["content"],
|
||||
max_token_size=query_param.max_token_for_text_unit,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
if not maybe_trun_chunks:
|
||||
@@ -1213,7 +1218,6 @@ async def mix_kg_vector_query(
|
||||
if query_param.only_need_prompt:
|
||||
return sys_prompt
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
|
||||
|
||||
@@ -1263,6 +1267,7 @@ async def _build_query_context(
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
):
|
||||
logger.info(f"Process {os.getpid()} buidling query context...")
|
||||
if query_param.mode == "local":
|
||||
@@ -1272,6 +1277,7 @@ async def _build_query_context(
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
global_config,
|
||||
)
|
||||
elif query_param.mode == "global":
|
||||
entities_context, relations_context, text_units_context = await _get_edge_data(
|
||||
@@ -1280,6 +1286,7 @@ async def _build_query_context(
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
global_config,
|
||||
)
|
||||
else: # hybrid mode
|
||||
ll_data = await _get_node_data(
|
||||
@@ -1288,6 +1295,7 @@ async def _build_query_context(
|
||||
entities_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
global_config,
|
||||
)
|
||||
hl_data = await _get_edge_data(
|
||||
hl_keywords,
|
||||
@@ -1295,6 +1303,7 @@ async def _build_query_context(
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
global_config,
|
||||
)
|
||||
|
||||
(
|
||||
@@ -1341,6 +1350,7 @@ async def _get_node_data(
|
||||
entities_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
):
|
||||
# get similar entities
|
||||
logger.info(
|
||||
@@ -1377,17 +1387,19 @@ async def _get_node_data(
|
||||
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
||||
# get entitytext chunk
|
||||
use_text_units = await _find_most_related_text_unit_from_entities(
|
||||
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
|
||||
)
|
||||
use_relations = await _find_most_related_edges_from_entities(
|
||||
node_datas, query_param, knowledge_graph_inst
|
||||
node_datas, query_param, knowledge_graph_inst, global_config
|
||||
)
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
len_node_datas = len(node_datas)
|
||||
node_datas = truncate_list_by_token_size(
|
||||
node_datas,
|
||||
key=lambda x: x["description"] if x["description"] is not None else "",
|
||||
max_token_size=query_param.max_token_for_local_context,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
logger.debug(
|
||||
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
|
||||
@@ -1481,6 +1493,7 @@ async def _find_most_related_text_unit_from_entities(
|
||||
query_param: QueryParam,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict[str, str],
|
||||
):
|
||||
text_units = [
|
||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||
@@ -1562,14 +1575,15 @@ async def _find_most_related_text_unit_from_entities(
|
||||
logger.warning("No valid text units found")
|
||||
return []
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
all_text_units = sorted(
|
||||
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
|
||||
)
|
||||
|
||||
all_text_units = truncate_list_by_token_size(
|
||||
all_text_units,
|
||||
key=lambda x: x["data"]["content"],
|
||||
max_token_size=query_param.max_token_for_text_unit,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -1584,6 +1598,7 @@ async def _find_most_related_edges_from_entities(
|
||||
node_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict[str, str],
|
||||
):
|
||||
node_names = [dp["entity_name"] for dp in node_datas]
|
||||
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
|
||||
@@ -1623,6 +1638,7 @@ async def _find_most_related_edges_from_entities(
|
||||
}
|
||||
all_edges_data.append(combined)
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
all_edges_data = sorted(
|
||||
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1630,6 +1646,7 @@ async def _find_most_related_edges_from_entities(
|
||||
all_edges_data,
|
||||
key=lambda x: x["description"] if x["description"] is not None else "",
|
||||
max_token_size=query_param.max_token_for_global_context,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -1645,6 +1662,7 @@ async def _get_edge_data(
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
query_param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
):
|
||||
logger.info(
|
||||
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
||||
@@ -1685,6 +1703,7 @@ async def _get_edge_data(
|
||||
}
|
||||
edge_datas.append(combined)
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1692,13 +1711,14 @@ async def _get_edge_data(
|
||||
edge_datas,
|
||||
key=lambda x: x["description"] if x["description"] is not None else "",
|
||||
max_token_size=query_param.max_token_for_global_context,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
use_entities, use_text_units = await asyncio.gather(
|
||||
_find_most_related_entities_from_relationships(
|
||||
edge_datas, query_param, knowledge_graph_inst
|
||||
edge_datas, query_param, knowledge_graph_inst, global_config
|
||||
),
|
||||
_find_related_text_unit_from_relationships(
|
||||
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
|
||||
edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
@@ -1778,6 +1798,7 @@ async def _find_most_related_entities_from_relationships(
|
||||
edge_datas: list[dict],
|
||||
query_param: QueryParam,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict[str, str],
|
||||
):
|
||||
entity_names = []
|
||||
seen = set()
|
||||
@@ -1808,11 +1829,13 @@ async def _find_most_related_entities_from_relationships(
|
||||
combined = {**node, "entity_name": entity_name, "rank": degree}
|
||||
node_datas.append(combined)
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
len_node_datas = len(node_datas)
|
||||
node_datas = truncate_list_by_token_size(
|
||||
node_datas,
|
||||
key=lambda x: x["description"] if x["description"] is not None else "",
|
||||
max_token_size=query_param.max_token_for_local_context,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
logger.debug(
|
||||
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
|
||||
@@ -1826,6 +1849,7 @@ async def _find_related_text_unit_from_relationships(
|
||||
query_param: QueryParam,
|
||||
text_chunks_db: BaseKVStorage,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict[str, str],
|
||||
):
|
||||
text_units = [
|
||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||
@@ -1867,10 +1891,12 @@ async def _find_related_text_unit_from_relationships(
|
||||
logger.warning("No valid text chunks after filtering")
|
||||
return []
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
truncated_text_units = truncate_list_by_token_size(
|
||||
valid_text_units,
|
||||
key=lambda x: x["data"]["content"],
|
||||
max_token_size=query_param.max_token_for_text_unit,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -1941,10 +1967,12 @@ async def naive_query(
|
||||
logger.warning("No valid chunks found after filtering")
|
||||
return PROMPTS["fail_response"]
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
maybe_trun_chunks = truncate_list_by_token_size(
|
||||
valid_chunks,
|
||||
key=lambda x: x["content"],
|
||||
max_token_size=query_param.max_token_for_text_unit,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
if not maybe_trun_chunks:
|
||||
@@ -1982,7 +2010,6 @@ async def naive_query(
|
||||
if query_param.only_need_prompt:
|
||||
return sys_prompt
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
|
||||
|
||||
@@ -2101,6 +2128,7 @@ async def kg_query_with_keywords(
|
||||
relationships_vdb,
|
||||
text_chunks_db,
|
||||
query_param,
|
||||
global_config,
|
||||
)
|
||||
if not context:
|
||||
return PROMPTS["fail_response"]
|
||||
|
@@ -424,7 +424,7 @@ def is_float_regex(value: str) -> bool:
|
||||
|
||||
|
||||
def truncate_list_by_token_size(
|
||||
list_data: list[Any], key: Callable[[Any], str], max_token_size: int
|
||||
list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
|
||||
) -> list[int]:
|
||||
"""Truncate a list of data by token size"""
|
||||
if max_token_size <= 0:
|
||||
|
Reference in New Issue
Block a user