fix truncation with global_config tokenizer
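This change threads the tokenizer stored in `global_config["tokenizer"]` through the query path: `_build_query_context` and the `_get_node_data` / `_get_edge_data` helpers (and their `_find_most_related_*` / `_find_related_*` sub-helpers) now accept `global_config`, bind the tokenizer once per function, and pass it explicitly to `truncate_list_by_token_size`, whose signature gains a `tokenizer` parameter in the final hunk.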
@@ -842,6 +842,7 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )

     if query_param.only_need_context:
@@ -1057,6 +1058,8 @@ async def mix_kg_vector_query(
     2. Retrieving relevant text chunks through vector similarity
     3. Combining both results for comprehensive answer generation
     """
+    # get tokenizer
+    tokenizer: Tokenizer = global_config["tokenizer"]
     # 1. Cache handling
     use_model_func = (
         query_param.model_func
@@ -1111,6 +1114,7 @@ async def mix_kg_vector_query(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )

         return context
@@ -1156,6 +1160,7 @@ async def mix_kg_vector_query(
             valid_chunks,
             key=lambda x: x["content"],
             max_token_size=query_param.max_token_for_text_unit,
+            tokenizer=tokenizer,
         )

         if not maybe_trun_chunks:
@@ -1213,7 +1218,6 @@ async def mix_kg_vector_query(
     if query_param.only_need_prompt:
         return sys_prompt

-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")

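The deletion in the `@@ -1213,7 +1218,6 @@` hunk is the counterpart of the addition at the top of `mix_kg_vector_query`: the tokenizer is now bound once after the docstring, so the later `len_of_prompts` computation reuses that earlier binding instead of re-reading `global_config`.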
@@ -1263,6 +1267,7 @@ async def _build_query_context(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(f"Process {os.getpid()} buidling query context...")
     if query_param.mode == "local":
@@ -1272,6 +1277,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
    elif query_param.mode == "global":
        entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1280,6 +1286,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
    else:  # hybrid mode
        ll_data = await _get_node_data(
@@ -1288,6 +1295,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
         hl_data = await _get_edge_data(
             hl_keywords,
@@ -1295,6 +1303,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )

         (
@@ -1341,6 +1350,7 @@ async def _get_node_data(
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     # get similar entities
     logger.info(
@@ -1377,17 +1387,19 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst
+        node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst
+        node_datas, query_param, knowledge_graph_inst, global_config
     )

+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
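How `global_config["tokenizer"]` gets populated is outside this diff. A plausible wiring sketch follows, assuming a tiktoken-backed implementation; the adapter class and the config key usage below are for illustration, and only `tiktoken.get_encoding`, `.encode`, and `.decode` are real library API.

```python
# Hypothetical wiring for the config entry the new code reads.
import tiktoken


class TiktokenTokenizer:
    """Adapter exposing encode()/decode() over a tiktoken encoding."""

    def __init__(self, encoding_name: str = "cl100k_base") -> None:
        self._encoding = tiktoken.get_encoding(encoding_name)

    def encode(self, content: str) -> list[int]:
        return self._encoding.encode(content)

    def decode(self, tokens: list[int]) -> str:
        return self._encoding.decode(tokens)


# Downstream code then reads global_config["tokenizer"] as in the hunks above.
global_config = {"tokenizer": TiktokenTokenizer()}
```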
@@ -1481,6 +1493,7 @@ async def _find_most_related_text_unit_from_entities(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1562,14 +1575,15 @@ async def _find_most_related_text_unit_from_entities(
         logger.warning("No valid text units found")
         return []

+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_text_units = sorted(
         all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
     )

     all_text_units = truncate_list_by_token_size(
         all_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )

     logger.debug(
@@ -1584,6 +1598,7 @@ async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     node_names = [dp["entity_name"] for dp in node_datas]
     batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1623,6 +1638,7 @@ async def _find_most_related_edges_from_entities(
         }
         all_edges_data.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_edges_data = sorted(
         all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1630,6 +1646,7 @@ async def _find_most_related_edges_from_entities(
         all_edges_data,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )

     logger.debug(
@@ -1645,6 +1662,7 @@ async def _get_edge_data(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(
         f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1685,6 +1703,7 @@ async def _get_edge_data(
         }
         edge_datas.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1692,13 +1711,14 @@ async def _get_edge_data(
         edge_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst
+            edge_datas, query_param, knowledge_graph_inst, global_config
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+            edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
         ),
     )
     logger.info(
@@ -1778,6 +1798,7 @@ async def _find_most_related_entities_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     entity_names = []
     seen = set()
@@ -1808,11 +1829,13 @@ async def _find_most_related_entities_from_relationships(
         combined = {**node, "entity_name": entity_name, "rank": degree}
         node_datas.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1826,6 +1849,7 @@ async def _find_related_text_unit_from_relationships(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1867,10 +1891,12 @@ async def _find_related_text_unit_from_relationships(
         logger.warning("No valid text chunks after filtering")
         return []

+    tokenizer: Tokenizer = global_config["tokenizer"]
     truncated_text_units = truncate_list_by_token_size(
         valid_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )

     logger.debug(
@@ -1941,10 +1967,12 @@ async def naive_query(
         logger.warning("No valid chunks found after filtering")
         return PROMPTS["fail_response"]

+    tokenizer: Tokenizer = global_config["tokenizer"]
     maybe_trun_chunks = truncate_list_by_token_size(
         valid_chunks,
         key=lambda x: x["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )

     if not maybe_trun_chunks:
@@ -1982,7 +2010,6 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt

-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")

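The removal here mirrors the one in `mix_kg_vector_query`: `naive_query` now obtains the tokenizer before its truncation step (see the `@@ -1941` hunk above), so the late lookup beside the prompt-length logging is redundant.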
@@ -2101,6 +2128,7 @@ async def kg_query_with_keywords(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )
     if not context:
         return PROMPTS["fail_response"]
@@ -424,7 +424,7 @@ def is_float_regex(value: str) -> bool:


 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
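For context, here is a sketch of how `truncate_list_by_token_size` can honor the new `tokenizer` parameter. The incremental token-count loop is an assumption consistent with the docstring and the `max_token_size <= 0` guard, not necessarily the exact body; note also that the diff carries over the existing `-> list[int]` return annotation even though the function returns a prefix of `list_data`, so the sketch uses `list[Any]` instead.

```python
# Sketch under the assumptions stated above.
from typing import Any, Callable


def truncate_list_by_token_size(
    list_data: list[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    tokenizer: "Tokenizer",  # the Protocol sketched earlier
) -> list[Any]:
    """Truncate a list of data by token size."""
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, item in enumerate(list_data):
        # Count tokens of the keyed text with the injected tokenizer.
        tokens += len(tokenizer.encode(key(item)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data
```

Injecting the tokenizer rather than importing one inside the helper lets the single tokenizer configured in `global_config` govern every truncation site touched above.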