From 0f949dd5d7ed5034f6233cb6f21616a0fd8897d6 Mon Sep 17 00:00:00 2001
From: drahnreb <25883607+drahnreb@users.noreply.github.com>
Date: Thu, 17 Apr 2025 13:09:52 +0200
Subject: [PATCH] fix truncation with global_config tokenizer

---
 lightrag/operate.py | 42 +++++++++++++++++++++++++++++++++++-------
 lightrag/utils.py   |  2 +-
 2 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/lightrag/operate.py b/lightrag/operate.py
index 8f79dcf8..13d60289 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -842,6 +842,7 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )
 
     if query_param.only_need_context:
@@ -1057,6 +1058,8 @@ async def mix_kg_vector_query(
     2. Retrieving relevant text chunks through vector similarity
     3. Combining both results for comprehensive answer generation
     """
+    # get tokenizer
+    tokenizer: Tokenizer = global_config["tokenizer"]
     # 1. Cache handling
     use_model_func = (
         query_param.model_func
@@ -1111,6 +1114,7 @@
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
 
         return context
@@ -1156,6 +1160,7 @@
                 valid_chunks,
                 key=lambda x: x["content"],
                 max_token_size=query_param.max_token_for_text_unit,
+                tokenizer=tokenizer,
             )
 
             if not maybe_trun_chunks:
@@ -1213,7 +1218,6 @@
     if query_param.only_need_prompt:
         return sys_prompt
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
 
@@ -1263,6 +1267,7 @@ async def _build_query_context(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(f"Process {os.getpid()} buidling query context...")
     if query_param.mode == "local":
@@ -1272,6 +1277,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
     elif query_param.mode == "global":
         entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1280,6 +1286,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
     else:  # hybrid mode
         ll_data = await _get_node_data(
@@ -1288,6 +1295,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
         hl_data = await _get_edge_data(
             hl_keywords,
@@ -1295,6 +1303,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
 
     (
@@ -1341,6 +1350,7 @@ async def _get_node_data(
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     # get similar entities
     logger.info(
@@ -1377,17 +1387,19 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing.  dont remember it in airvx.  check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst
+        node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst
+        node_datas, query_param, knowledge_graph_inst, global_config
     )
+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
     )
@@ -1481,6 +1493,7 @@ async def _find_most_related_text_unit_from_entities(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1562,14 +1575,15 @@ async def _find_most_related_text_unit_from_entities(
         logger.warning("No valid text units found")
         return []
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_text_units = sorted(
         all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
     )
-
     all_text_units = truncate_list_by_token_size(
         all_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )
 
     logger.debug(
@@ -1584,6 +1598,7 @@ async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     node_names = [dp["entity_name"] for dp in node_datas]
     batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1623,6 +1638,7 @@ async def _find_most_related_edges_from_entities(
             }
             all_edges_data.append(combined)
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_edges_data = sorted(
         all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1630,6 +1646,7 @@ async def _find_most_related_edges_from_entities(
         all_edges_data,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )
 
     logger.debug(
@@ -1645,6 +1662,7 @@ async def _get_edge_data(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(
         f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1685,6 +1703,7 @@ async def _get_edge_data(
             }
             edge_datas.append(combined)
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1692,13 +1711,14 @@ async def _get_edge_data(
         edge_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst
+            edge_datas, query_param, knowledge_graph_inst, global_config
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+            edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
         ),
     )
     logger.info(
@@ -1778,6 +1798,7 @@ async def _find_most_related_entities_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     entity_names = []
     seen = set()
@@ -1808,11 +1829,13 @@ async def _find_most_related_entities_from_relationships(
         combined = {**node, "entity_name": entity_name, "rank": degree}
         node_datas.append(combined)
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1826,6 +1849,7 @@ async def _find_related_text_unit_from_relationships(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1867,10 +1891,12 @@ async def _find_related_text_unit_from_relationships(
         logger.warning("No valid text chunks after filtering")
         return []
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     truncated_text_units = truncate_list_by_token_size(
         valid_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )
 
     logger.debug(
@@ -1941,10 +1967,12 @@ async def naive_query(
         logger.warning("No valid chunks found after filtering")
         return PROMPTS["fail_response"]
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     maybe_trun_chunks = truncate_list_by_token_size(
         valid_chunks,
         key=lambda x: x["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )
 
     if not maybe_trun_chunks:
@@ -1982,7 +2010,6 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
 
@@ -2101,6 +2128,7 @@ async def kg_query_with_keywords(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )
     if not context:
         return PROMPTS["fail_response"]
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 3aded98a..0d490612 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -424,7 +424,7 @@ def is_float_regex(value: str) -> bool:
 
 
 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
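
Note: the patch makes truncate_list_by_token_size receive its tokenizer explicitly instead of resolving one internally; every call site in operate.py now pulls tokenizer = global_config["tokenizer"] once per function and passes it down. For reference, a minimal sketch of what the updated helper amounts to, assuming only that Tokenizer exposes encode(text) -> list[int] (the one method the patch itself relies on, via len(tokenizer.encode(query + sys_prompt))). The accumulation loop is illustrative, since the diff shows only the new signature and the max_token_size guard, and the return annotation is written as list[Any] here even though the upstream signature reads list[int]:

    from typing import Any, Callable

    def truncate_list_by_token_size(
        list_data: list[Any],
        key: Callable[[Any], str],
        max_token_size: int,
        tokenizer: "Tokenizer",  # any object with encode(str) -> list[int]
    ) -> list[Any]:
        """Keep the longest prefix of list_data whose key() texts fit the token budget."""
        if max_token_size <= 0:
            return []
        tokens = 0
        for i, item in enumerate(list_data):
            # Illustrative accumulation: count tokens per item with the
            # injected tokenizer and cut the list once the budget is exceeded.
            tokens += len(tokenizer.encode(key(item)))
            if tokens > max_token_size:
                return list_data[:i]
        return list_data

Passing the tokenizer as an argument keeps the helper free of module-level state, so the same code path serves kg_query, mix_kg_vector_query, naive_query, and the _find_* helpers, and the truncation logic can be unit-tested with a stub tokenizer whose encode() is, for example, str.split.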